{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "9bd01afc",
   "metadata": {},
   "source": [
    "# NeMo Curator Pipeline Example\n",
    "\n",
    "## NeMo Curator Introduction\n",
    "NeMo Curator is a Python library that consists of a collection of scalable data-mining modules for curating natural language processing (NLP) data for training large language models (LLMs). The modules within NeMo Curator enable NLP researchers to mine high-quality text at scale from massive uncurated web corpora.\n",
    "\n",
    "NeMo Curator includes the following modules to perform data curation:\n",
    "- Data download and extraction\n",
    "- Language identification and separation\n",
    "- Text reformatting and cleaning\n",
    "- Quality filtering\n",
    "- Document-level deduplication\n",
    "- Multilingual downstream-task decontamination\n",
    "- Distributed data classification\n",
    "- Personally identifiable information (PII) redaction\n",
    "\n",
    "The NeMo Curator team performed ablation experiments in which a 357M-parameter GPT-style model was trained on Common Crawl data to assess the effect of each curation stage on model performance.\n",
    "\n",
    "![Zero-shot ablation results](./image/zeroshot_ablations.png)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7b1808ea",
   "metadata": {},
   "source": [
    "## About this notebook\n",
    "\n",
    "This notebook uses the **Thai Wikipedia dataset** as an example to demonstrate a typical data curation pipeline with NeMo Curator. After working through it, you will know how to use NeMo Curator to download Wikipedia data, separate documents by language with fastText, run GPU-based exact and fuzzy deduplication, and apply CPU-based heuristic filtering.\n",
    "\n",
    "Step description:\n",
    "1. Download and extract data\n",
    "2. Language detection and separation\n",
    "3. GPU-based deduplication\n",
    "    1. Exact deduplication\n",
    "    2. Fuzzy deduplication\n",
    "4. Heuristic filtering\n",
    "\n",
    "What is not included:\n",
    "1. Customized downloading\n",
    "2. Classifier filtering\n",
    "3. Downstream-task decontamination\n",
    "4. Distributed data classification with PyTorch models\n",
    "5. Personally identifiable information (PII) redaction\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "78537bd7",
   "metadata": {},
   "source": [
    "## Prerequisites\n",
    "\n",
    "### System Requirements\n",
    "This notebook was run with the following hardware and software configuration:\n",
    "\n",
    "**GPU**: NVIDIA A10 (24 GB)\n",
    "\n",
    "**CUDA & NVIDIA driver**: CUDA 12.2 with driver 535.154.05\n",
    "\n",
    "**OS**: Ubuntu 22.04\n",
    "\n",
    "### Getting the NeMo Framework Training Container\n",
    "- Get access to the container via https://developer.nvidia.com/nemo-framework\n",
    "- Set your Docker credentials:\n",
    "    ```bash\n",
    "    docker login nvcr.io\n",
    "\n",
    "    Username: $oauthtoken\n",
    "    Password: <Your NGC Key>\n",
    "    ```\n",
    "- Pull the NeMo Framework Training Container:\n",
    "    ```bash\n",
    "    docker pull nvcr.io/nvidia/nemo:dev.framework\n",
    "    ```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "062b5423",
   "metadata": {},
   "source": [
    "## 0. Environment Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "8add9bbd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com\n",
      "Requirement already satisfied: jsonlines in /usr/local/lib/python3.10/dist-packages (4.0.0)\n",
      "Requirement already satisfied: attrs>=19.2.0 in /usr/local/lib/python3.10/dist-packages (from jsonlines) (23.2.0)\n",
      "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
      "\u001b[0m"
     ]
    }
   ],
   "source": [
    "!pip install jsonlines"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "9940c70d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import argparse\n",
    "\n",
    "from nemo_curator.utils.distributed_utils import get_client, get_num_workers, read_data, write_to_disk\n",
    "from nemo_curator.utils.script_utils import add_distributed_args\n",
    "from nemo_curator.utils.file_utils import get_all_files_paths_under, separate_by_metadata\n",
    "from nemo_curator.datasets import DocumentDataset\n",
    "\n",
    "import os\n",
    "import sys\n",
    "import pandas as pd\n",
    "import time\n",
    "import cudf\n",
    "import dask_cudf\n",
    "import dask\n",
    "import numpy as np\n",
    "from dask.distributed import Client, LocalCluster\n",
    "import jsonlines\n",
    "\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "fd8a381d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imports cudf; meant to be run on each Dask worker via client.run()\n",
    "def pre_imports():\n",
    "    import cudf\n",
    "\n",
    "# Returns an argument parser extended with NeMo Curator's distributed arguments\n",
    "def attach_args(parser=argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)):\n",
    "    return add_distributed_args(parser)\n",
    "\n",
    "# Prints the first line of the first .jsonl file found in file_dir\n",
    "def check_jsonl_file(file_dir):\n",
    "    for file in os.listdir(file_dir):\n",
    "        if 'jsonl' not in file:\n",
    "            continue\n",
    "        with open(os.path.join(file_dir, file), 'r', encoding='utf-8') as f:\n",
    "            first_line = f.readline()\n",
    "            print(first_line)\n",
    "        break\n",
    "\n",
    "# Yields every record in a .jsonl file whose 'id' is in target_list\n",
    "def extract_lines_with_id(file_path, target_list):\n",
    "    with jsonlines.open(file_path) as reader:\n",
    "        for obj in reader:\n",
    "            if obj.get('id') in target_list:\n",
    "                yield obj"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "589ff257",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/work_dir/tutorials/single_node_tutorial\n"
     ]
    }
   ],
   "source": [
    "cur_dir = os.getcwd()\n",
    "print(cur_dir)\n",
    "data_dir = f\"{cur_dir}/workspace/\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "662d505f",
   "metadata": {},
   "source": [
    "## 1. Download\n",
    "In this example, Thai Wikipedia data will be downloaded.\n",
    "\n",
    "Here is what happens when the function `download_wikipedia()` is called:\n",
    "1. Run `get_wikipedia_urls()` to obtain a list of URLs for the Thai Wikipedia .bz2 dump files. This function combines the base link with the user-specified language to form the repository link `https://dumps.wikimedia.org/<language>wiki`, from which the downloadable .bz2 dump files are listed. All the links are stored in a .txt file. Arguments for this function include:\n",
    "    - `dump_date`: A date string in the format 'YYYYMMDD'. It determines which Wikipedia snapshot will be downloaded. If not specified, the `latest` snapshot will be downloaded.\n",
    "    - `language`: Lowercase language code of the desired language. Default value is `en`.\n",
    "\n",
    "2. Run `download_and_extract()` to download and extract the content based on the URL list obtained from `get_wikipedia_urls()`. The user needs to define a `downloader`, an `iterator` and an `extractor` for the dataset. In this case, `WikipediaDownloader`, `WikipediaIterator` and `WikipediaExtractor` are used.\n",
    "    - `WikipediaDownloader`: Downloads the Wikipedia dump files to a local folder.\n",
    "    - `WikipediaIterator`: Extracts the .bz2 files and the useful content from the base HTML content.\n",
    "    - `WikipediaExtractor`: Performs further task-specific HTML content cleaning, such as removing media files, references and tables, and finally yields pure text data, which is stored in .jsonl format.\n",
    "\n",
    "    Please refer to `./NeMo-Curator/nemo_curator/download/wikipedia.py` for implementation details.\n",
    "\n",
    "    Arguments for `download_wikipedia()` include:\n",
    "    - `output_path`: Output path for the downloaded and extracted dataset.\n",
    "    - `output_type`: Type of output file. Default is .jsonl. Other types such as parquet may be chosen. In this example, .jsonl is used.\n",
    "    - `language`: See above.\n",
    "    - `dump_date`: See above.\n",
    "    - `raw_download_dir`: Output path for the intermediate downloaded .bz2 files. If not specified, they are downloaded to `output_path`.\n",
    "    - `keep_raw_download`: Whether to keep the downloaded .bz2 files after extraction. Default is not to keep them.\n",
    "    - `force_download`: Whether to restart the download if the target .bz2 files are already present under `raw_download_dir`.\n",
    "    - `url_limit`: Number of .bz2 files to be downloaded.\n",
    "\n",
    "The resulting .jsonl for Thai Wikipedia will contain the following keys:\n",
    "1. text\n",
    "2. title\n",
    "3. id\n",
    "4. url\n",
    "5. language\n",
    "6. source_id\n",
    "7. filename"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "adb59379",
   "metadata": {},
   "outputs": [],
   "source": [
    "from nemo_curator.download import download_wikipedia"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9b56f12a",
   "metadata": {},
   "source": [
    "Start a CPU-based Dask cluster. Please modify `n_workers` and `memory_limit` according to your hardware specification. To process the Thai Wikipedia data, a `memory_limit` greater than 12GB is advised."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e822b5ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "cluster = LocalCluster(n_workers=10, processes=True, memory_limit='16GB')\n",
    "client = Client(cluster)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e90cc8b1",
   "metadata": {},
   "source": [
    "Define parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "9a03b463",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Output\n",
    "download_base_directory= os.path.join(data_dir,\"wiki_downloads\")\n",
    "download_output_directory = os.path.join(download_base_directory,\"data\")\n",
    "\n",
    "#Relevant parameters\n",
    "dump_date = \"20240201\"\n",
    "language = 'th'\n",
    "url_limit = 1"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f41734a1",
   "metadata": {},
   "source": [
    "Download TH wikipedia data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a45965a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "res = download_wikipedia(download_output_directory,\n",
    "                   language=language, \n",
    "                   dump_date=dump_date,\n",
    "                   url_limit=url_limit).df.compute()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "22b7d5b3",
   "metadata": {},
   "source": [
    "Verify result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "45a69041",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "downloads  thwiki-20240201-pages-articles-multistream.xml.bz2.jsonl\n",
      "162164 /nluo_data/NeMo-Curator/tutorials/single_node_tutorial/workspace/wiki_downloads/data/thwiki-20240201-pages-articles-multistream.xml.bz2.jsonl\n"
     ]
    }
   ],
   "source": [
    "! ls {download_output_directory}\n",
    "! wc -l  {download_output_directory}/thwiki-20240201-pages-articles-multistream.xml.bz2.jsonl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "53bdccfd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\"text\":\"–\\n\\nป้ายบอกทาง \\n ศาลาประชาคม – กระดานข่าว โครงการ ทรัพยากรและกิจกรรมซึ่งครอบคลุมวิกิพีเดียอย่างกว้างขวาง\\n แผนกช่วยเหลือ – ถามข้อสงสัยเกี่ยวกับการใช้งานวิกิพีเดีย\\n ปุจฉา-วิสัชนา – ถามข้อสงสัยทั่วไปที่คุณอยากรู้\\n ข่าวไซต์ – ประกาศ อัพเดต บทความและข้อมูลข่าวเกี่ยวกับวิกิพีเดียและมูลนิธิวิกิมีเดีย\\n สภากาแฟ – สำหรับอภิปรายเกี่ยวกับวิกิพีเดีย รวมถึงรายงานปัญหาเทคนิคและเสนอนโยบาย\\n Local Embassy – For Wikipedia-related discussion in languages other than Thai.\\n สร้างบทความใหม่ – บทช่วยสอนสำหรับเตรียมพร้อมสร้างบทความแรกของคุณ\\n\\nภาษาอื่น \\n\\n \",\"title\":\"หน้าหลัก\",\"id\":\"1\",\"url\":\"https:\\/\\/th.wikipedia.org\\/wiki\\/%E0%B8%AB%E0%B8%99%E0%B9%89%E0%B8%B2%E0%B8%AB%E0%B8%A5%E0%B8%B1%E0%B8%81\",\"language\":\"th\",\"source_id\":\"thwiki-20240201-thwiki-20240201-pages-articles-multistream.xml.bz2\",\"filename\":\"thwiki-20240201-pages-articles-multistream.xml.bz2.jsonl\"}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "check_jsonl_file(download_output_directory)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c5f58643",
   "metadata": {},
   "source": [
    "**[Optional]** Close the Dask cluster. You might encounter an error such as `Caught signal 11`. This is fine; just rerun the cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "0669a830",
   "metadata": {},
   "outputs": [],
   "source": [
    "# client.cluster.close()\n",
    "# client.shutdown()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "43334988",
   "metadata": {},
   "source": [
    "## 2. Language separation and Unicode fixing"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "86ccdc1f",
   "metadata": {},
   "source": [
    "In this section, we use a fastText language classification model to separate the Thai Wikipedia dataset by each document's major language, and we also fix the Unicode in the documents. The detailed steps are:\n",
    "\n",
    "1. Download the fastText model for text language detection.\n",
    "2. Construct a filter that uses the downloaded fastText model to assign a language label to each document.\n",
    "3. Separate the documents by language label. This creates a sub-folder for each language under the output path, and the documents of each language are written to a .jsonl file in the corresponding sub-folder.\n",
    "4. Load the .jsonl file from the folder of the desired language. In this example, the `TH` folder is loaded.\n",
    "5. Apply `UnicodeReformatter` to the data and output the result in .jsonl format.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "1e9198e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from nemo_curator import ScoreFilter,Modify\n",
    "from nemo_curator.filters import FastTextLangId\n",
    "from nemo_curator.modifiers import UnicodeReformatter"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "76e46d2a",
   "metadata": {},
   "source": [
    "**[Optional]** Start a CPU-based Dask cluster."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "da3aed8a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# cluster = LocalCluster(n_workers=10, processes=True, memory_limit='16GB')\n",
    "# client = Client(cluster)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4a72479c",
   "metadata": {},
   "source": [
    "Define parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "13b9d2b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Input path\n",
    "multilingual_data_path = f\"{download_output_directory}/thwiki-20240201-pages-articles-multistream.xml.bz2.jsonl\"\n",
    "\n",
    "# Output path\n",
    "language_base_output_path = os.path.join(data_dir,\"language_sep\")\n",
    "language_data_output_path = os.path.join(language_base_output_path,\"data\")\n",
    "language_separated_output_path = os.path.join(language_data_output_path,\"language\")\n",
    "lang_sep_cleaned_data_output_path = os.path.join(language_data_output_path,\"cleaned\")\n",
    "\n",
    "# Fasttext model path\n",
    "model_path = language_base_output_path\n",
    "\n",
    "# Define desired language\n",
    "target_language = \"TH\"\n",
    "\n",
    "# Define key in output .jsonl files to store the language information\n",
    "language_field = \"language\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8df0322a",
   "metadata": {},
   "source": [
    "Download fasttext model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "2666727d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2024-05-17 03:17:09--  https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin\n",
      "Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 99.84.238.181, 99.84.238.154, 99.84.238.162, ...\n",
      "Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|99.84.238.181|:443... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 131266198 (125M) [application/octet-stream]\n",
      "Saving to: ‘/nluo_data/NeMo-Curator/tutorials/single_node_tutorial/workspace/language_sep/lid.176.bin.1’\n",
      "\n",
      "lid.176.bin.1       100%[===================>] 125.18M   184MB/s    in 0.7s    \n",
      "\n",
      "2024-05-17 03:17:10 (184 MB/s) - ‘/nluo_data/NeMo-Curator/tutorials/single_node_tutorial/workspace/language_sep/lid.176.bin.1’ saved [131266198/131266198]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "!wget https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin -P {model_path}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "58452516",
   "metadata": {},
   "source": [
    "Apply fasttext model to separate documents by their languages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "d8b8c491",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Reading 1 files\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Warning : `load_model` does not return WordVectorModel or SupervisedModel any more, but a `FastText` object which is very similar.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Time taken for splitting language:140.04064464569092\n"
     ]
    }
   ],
   "source": [
    "t0 = time.time()\n",
    "\n",
    "# Load dataset\n",
    "multilingual_dataset = DocumentDataset.read_json(multilingual_data_path,add_filename=True)\n",
    "\n",
    "# Define language separation pipeline\n",
    "lang_filter = FastTextLangId(os.path.join(model_path,'lid.176.bin'))\n",
    "language_id_pipeline = ScoreFilter(lang_filter, score_field=language_field, score_type='object')\n",
    "filtered_dataset = language_id_pipeline(multilingual_dataset)\n",
    "\n",
    "# The pipeline stores a [score, label] pair per document, e.g. [0.96873, 'EN']. Keep only the language label (index 1) and drop the classifier score\n",
    "filtered_dataset.df[language_field] = filtered_dataset.df[language_field].apply(lambda score: score[1],meta = (language_field, 'object'))\n",
    "\n",
    "# Split the dataset to corresponding language sub-folders\n",
    "language_stats = separate_by_metadata(filtered_dataset.df, language_separated_output_path, metadata_field=language_field).compute()\n",
    "\n",
    "print(f\"Time taken for splitting language:{time.time()-t0}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d443a5d1",
   "metadata": {},
   "source": [
    "Apply `UnicodeReformatter` to fix any broken Unicode that appears in the dataset of the desired language"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "272a5f67",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Reading 1 files\n",
      "Writing to disk complete for 1 partitions\n",
      "Time taken for fixing unicode:437.4811737537384\n"
     ]
    }
   ],
   "source": [
    "t0 = time.time()\n",
    "\n",
    "# Read the language specific data and fix the unicode in it\n",
    "lang_data_path = os.path.join(language_separated_output_path, target_language)\n",
    "lang_data = DocumentDataset.read_json(lang_data_path,add_filename=True)\n",
    "\n",
    "cleaner = Modify(UnicodeReformatter())\n",
    "cleaned_data = cleaner(lang_data)\n",
    "\n",
    "# Write the cleaned_data\n",
    "cleaned_data.to_json(lang_sep_cleaned_data_output_path, write_to_filename=True)\n",
    "\n",
    "print(f\"Time taken for fixing unicode:{time.time()-t0}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9bd57a53",
   "metadata": {},
   "source": [
    "Verify the result. Some documents have been removed from the Thai Wikipedia dataset: this output file has fewer lines than the original file (162164 lines)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "e3329c83",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "thwiki-20240201-pages-articles-multistream.xml.bz2.jsonl\n",
      "161748 /nluo_data/NeMo-Curator/tutorials/single_node_tutorial/workspace/language_sep/data/cleaned/thwiki-20240201-pages-articles-multistream.xml.bz2.jsonl\n"
     ]
    }
   ],
   "source": [
    "! ls {lang_sep_cleaned_data_output_path}\n",
    "! wc -l  {lang_sep_cleaned_data_output_path}/thwiki-20240201-pages-articles-multistream.xml.bz2.jsonl"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0b6cbc26",
   "metadata": {},
   "source": [
    "Further verify by loading documents that have been identified as another language, such as 'EN'. The sampled document should be mostly English and contain little or no Thai."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "050d944c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\"filename\":\"thwiki-20240201-pages-articles-multistream.xml.bz2.jsonl\",\"id\":\"1\",\"language\":\"TH\",\"source_id\":\"thwiki-20240201-thwiki-20240201-pages-articles-multistream.xml.bz2\",\"text\":\"–\\n\\nป้ายบอกทาง \\n ศาลาประชาคม – กระดานข่าว โครงการ ทรัพยากรและกิจกรรมซึ่งครอบคลุมวิกิพีเดียอย่างกว้างขวาง\\n แผนกช่วยเหลือ – ถามข้อสงสัยเกี่ยวกับการใช้งานวิกิพีเดีย\\n ปุจฉา-วิสัชนา – ถามข้อสงสัยทั่วไปที่คุณอยากรู้\\n ข่าวไซต์ – ประกาศ อัพเดต บทความและข้อมูลข่าวเกี่ยวกับวิกิพีเดียและมูลนิธิวิกิมีเดีย\\n สภากาแฟ – สำหรับอภิปรายเกี่ยวกับวิกิพีเดีย รวมถึงรายงานปัญหาเทคนิคและเสนอนโยบาย\\n Local Embassy – For Wikipedia-related discussion in languages other than Thai.\\n สร้างบทความใหม่ – บทช่วยสอนสำหรับเตรียมพร้อมสร้างบทความแรกของคุณ\\n\\nภาษาอื่น \\n\\n \",\"title\":\"หน้าหลัก\",\"url\":\"https:\\/\\/th.wikipedia.org\\/wiki\\/%E0%B8%AB%E0%B8%99%E0%B9%89%E0%B8%B2%E0%B8%AB%E0%B8%A5%E0%B8%B1%E0%B8%81\"}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "check_jsonl_file(os.path.join(language_separated_output_path,'EN'))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7d17f010",
   "metadata": {},
   "source": [
    "**[Optional]** Close the Dask cluster."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "7e64cc35",
   "metadata": {},
   "outputs": [],
   "source": [
    "# client.cluster.close()\n",
    "# client.shutdown()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1d46cece",
   "metadata": {},
   "source": [
    "## 3. Add ID\n",
    "The Thai Wikipedia data already has an `id` field, but it contains numbers only. It is better to unify the `id` field into the format `<prefix>-<id>`, so that when handling multiple datasets we can tell which document from which dataset has been removed. This `id` will be useful when running deduplication and heuristic filtering. The function we will be using is `AddId()`. Arguments for this function include:\n",
    "- `id_field`: Name of the field to add to the input .jsonl file. If the key already exists in the .jsonl, its value will be replaced.\n",
    "- `id_prefix`: Prefix used in the ID. Default is 'doc_id'.\n",
    "- `start_index`: Starting index of the ID. Default is None. When set to None, an unordered ID scheme is used for fast calculation. In this notebook, it is set to 0 for easier reference."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "5f788b91",
   "metadata": {},
   "outputs": [],
   "source": [
    "from nemo_curator import AddId"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cd17be33",
   "metadata": {},
   "source": [
    "**[Optional]** If there is no running Dask cluster, start a CPU-based Dask cluster."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "5ba1d54a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# cluster = LocalCluster(n_workers=10, processes=True, memory_limit='16GB')\n",
    "# client = Client(cluster)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "12f59d5e",
   "metadata": {},
   "source": [
    "Define relevant parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "843eba7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Input\n",
    "add_id_input_data_dir = lang_sep_cleaned_data_output_path\n",
    "\n",
    "#Output\n",
    "added_id_output_path = os.path.join(data_dir,\"add_id/cleaned\")\n",
    "\n",
    "#Format of output ID will be <prefix>-<id>. Define the prefix here\n",
    "add_ID_id_prefix=\"TH_wiki\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e7a8307c",
   "metadata": {},
   "source": [
    "Adding ID to dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "b7a91bf1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Reading 1 files\n",
      "Writing to disk complete for 1 partitions\n",
      "Time taken for add ID:47.33783745765686\n"
     ]
    }
   ],
   "source": [
    "t0 = time.time()\n",
    "# Read input files\n",
    "dataset = DocumentDataset.read_json(add_id_input_data_dir,add_filename=True)\n",
    "\n",
    "# Run AddId() on the input dataset\n",
    "add_id = AddId(id_field='id',id_prefix=add_ID_id_prefix,start_index=0)\n",
    "id_dataset = add_id(dataset)\n",
    "\n",
    "#Output files\n",
    "id_dataset.to_json(added_id_output_path, write_to_filename=True)\n",
    "\n",
    "print(f\"Time taken for add ID:{time.time()-t0}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e92b5dab",
   "metadata": {},
   "source": [
    "Verify the result. From the output, we can see that the `id` value has been changed to `TH_wiki-0000000000` "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "e585cedd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\"filename\":\"thwiki-20240201-pages-articles-multistream.xml.bz2.jsonl\",\"id\":\"TH_wiki-0000000000\",\"language\":\"TH\",\"source_id\":\"thwiki-20240201-thwiki-20240201-pages-articles-multistream.xml.bz2\",\"text\":\"–\\n\\nป้ายบอกทาง \\n ศาลาประชาคม – กระดานข่าว โครงการ ทรัพยากรและกิจกรรมซึ่งครอบคลุมวิกิพีเดียอย่างกว้างขวาง\\n แผนกช่วยเหลือ – ถามข้อสงสัยเกี่ยวกับการใช้งานวิกิพีเดีย\\n ปุจฉา-วิสัชนา – ถามข้อสงสัยทั่วไปที่คุณอยากรู้\\n ข่าวไซต์ – ประกาศ อัพเดต บทความและข้อมูลข่าวเกี่ยวกับวิกิพีเดียและมูลนิธิวิกิมีเดีย\\n สภากาแฟ – สำหรับอภิปรายเกี่ยวกับวิกิพีเดีย รวมถึงรายงานปัญหาเทคนิคและเสนอนโยบาย\\n Local Embassy – For Wikipedia-related discussion in languages other than Thai.\\n สร้างบทความใหม่ – บทช่วยสอนสำหรับเตรียมพร้อมสร้างบทความแรกของคุณ\\n\\nภาษาอื่น \\n\\n \",\"title\":\"หน้าหลัก\",\"url\":\"https:\\/\\/th.wikipedia.org\\/wiki\\/%E0%B8%AB%E0%B8%99%E0%B9%89%E0%B8%B2%E0%B8%AB%E0%B8%A5%E0%B8%B1%E0%B8%81\"}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "check_jsonl_file(added_id_output_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0cbddf6e",
   "metadata": {},
   "source": [
    "Close the Dask cluster. This cell must be run because a new GPU-based Dask cluster is started in the next section."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "4daa1f2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "client.cluster.close()\n",
    "client.shutdown()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1baf027e",
   "metadata": {},
   "source": [
    "## 4. Exact Deduplication\n",
    "\n",
    "In exact deduplication, the text of each document is hashed with a hashing algorithm such as `md5`. Documents with identical hash values have identical text. We will output the `id` of each duplicated document for removal later. The function used is `ExactDuplicates()`. Arguments for this function include:\n",
    "- `id_field`: Key in the input file that identifies the document ID.\n",
    "- `text_field`: Key in the input file that contains the document text.\n",
    "- `hash_method`: Hashing algorithm used. Default is `md5`.\n",
    "- `cache_dir`: If specified, the duplicated document IDs are written to `cache_dir`. Otherwise, the IDs are not saved.\n",
    "\n",
    "We also use a GPU-based Dask cluster to accelerate the computation for deduplication (both exact and fuzzy).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "3f7ba34c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from nemo_curator.modules import ExactDuplicates"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e268cfca",
   "metadata": {},
   "source": [
    "Start a GPU-based Dask cluster. Since a GPU-based Dask cluster requires several arguments, we use the `get_client()` wrapper function to set it up quickly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "4b73e5f9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of dask worker:1\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'tcp://127.0.0.1:36179': None}"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "client = get_client(cluster_type = 'gpu', set_torch_to_use_rmm=False)\n",
    "print(f\"Number of dask worker:{get_num_workers(client)}\")\n",
    "client.run(pre_imports)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0fc99440",
   "metadata": {},
   "source": [
    "If you encounter the following error:\n",
    "`get_client() missing 1 required positional argument: 'args'`\n",
    "\n",
    "This is probably because the installed `nemo_curator` library is an older version. Please run the following line in the terminal, following the instructions in our [GitHub](https://github.com/nicoleeeluo/NeMo-Curator/tree/main) repo, and restart the notebook. Intermediate results from the previous sections have been saved to disk, so you can resume from this section after updating."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "a590c78a",
   "metadata": {},
   "outputs": [],
   "source": [
    "#pip install --extra-index-url https://pypi.nvidia.com \".[cuda12x]\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0151abe0",
   "metadata": {},
   "source": [
    "Define parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "54b627a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Input\n",
    "exact_dedup_input_dataset_dir = added_id_output_path\n",
    "\n",
    "#Output\n",
    "exact_dedup_base_output_path = os.path.join(data_dir,\"exact_dedup\")\n",
    "exact_dedup_log_dir = os.path.join(exact_dedup_base_output_path,'log')\n",
    "exact_dedup_output_dir = os.path.join(exact_dedup_base_output_path,'data')\n",
    "\n",
    "#Parameters for ExactDuplicates()\n",
    "exact_dedup_dataset_id_field = \"id\"\n",
    "exact_dedup_dataset_text_field = \"text\" \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "6ede2e41",
   "metadata": {},
   "outputs": [],
   "source": [
    "!mkdir -p {exact_dedup_log_dir}\n",
    "!mkdir -p {exact_dedup_output_dir}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1882204a",
   "metadata": {},
   "source": [
    "Apply exact deduplication"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "dfaaa765",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Reading 1 files\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.10/dist-packages/nemo_curator/modules/exact_dedup.py:158: UserWarning: Output path f/work_dir/tutorials/single_node_tutorial/workspace/exact_dedup/data/_exact_duplicates.parquet already exists and will be overwritten\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of exact duplicated file:53\n",
      "Time taken for exact duplicate:1.9788782596588135\n"
     ]
    }
   ],
   "source": [
    "t0 = time.time()\n",
    "# Read input dataset\n",
    "input_dataset = DocumentDataset.read_json(exact_dedup_input_dataset_dir, backend='cudf')\n",
    "\n",
    "#Run exact deduplication to the input\n",
    "exact_dup = ExactDuplicates(\n",
    "    logger=exact_dedup_log_dir,\n",
    "    id_field=exact_dedup_dataset_id_field,\n",
    "    text_field=exact_dedup_dataset_text_field,\n",
    "    hash_method=\"md5\",\n",
    "    cache_dir=exact_dedup_output_dir #Duplicated document ID list is output to the cache_dir\n",
    ")\n",
    "duplicates = exact_dup(dataset=input_dataset)\n",
    "\n",
    "print(f\"Number of exact duplicated file:{len(duplicates)}\")\n",
    "\n",
    "print(f\"Time taken for exact duplicate:{time.time()-t0}\")"
   ]
  },
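  {
   "cell_type": "markdown",
   "id": "3c1a9f00",
   "metadata": {},
   "source": [
    "Conceptually, exact deduplication hashes the text of every document and groups the documents whose hashes are identical. The snippet below is a minimal CPU-only sketch of that idea using `hashlib.md5`. It is not the `ExactDuplicates` implementation, and the toy corpus and variable names are made up for illustration.\n",
    "\n",
    "```python\n",
    "import hashlib\n",
    "from collections import defaultdict\n",
    "\n",
    "# Hypothetical toy corpus; the ids and texts are made up for illustration.\n",
    "docs = {\n",
    "    'doc-0': 'hello world',\n",
    "    'doc-1': 'hello world',\n",
    "    'doc-2': 'something else',\n",
    "}\n",
    "\n",
    "# Group document ids by the MD5 digest of their text.\n",
    "groups = defaultdict(list)\n",
    "for doc_id, text in docs.items():\n",
    "    digest = hashlib.md5(text.encode('utf-8')).hexdigest()\n",
    "    groups[digest].append(doc_id)\n",
    "\n",
    "# Keep only the hashes shared by more than one document.\n",
    "duplicate_groups = [ids for ids in groups.values() if len(ids) > 1]\n",
    "print(duplicate_groups)\n",
    "```\n",
    "\n",
    "Only byte-identical texts end up in the same group, which is why near-duplicates need the fuzzy deduplication described later."
   ]
  },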
  {
   "cell_type": "markdown",
   "id": "e68f0399",
   "metadata": {},
   "source": [
    "Verify the output list of duplicated IDs. We can group by `_hashes` to list the documents that share the same hash, then use `extract_lines_with_id()` to verify that those documents are indeed exact duplicates. Note that the `id` values might change between runs, so replace `target_list` when necessary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "28d8bb0b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of exact duplicated document:53\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>id</th>\n",
       "      <th>_hashes</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>TH_wiki-0000122055</td>\n",
       "      <td>3e6e96a80410d5a191d098f464e66f86</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>TH_wiki-0000105191</td>\n",
       "      <td>e77a248506ef16737288fae5759db33a</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>TH_wiki-0000105192</td>\n",
       "      <td>2e386f5c3af70f43874618988d4842b2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>TH_wiki-0000105193</td>\n",
       "      <td>2e386f5c3af70f43874618988d4842b2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>TH_wiki-0000105194</td>\n",
       "      <td>2e386f5c3af70f43874618988d4842b2</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                   id                           _hashes\n",
       "0  TH_wiki-0000122055  3e6e96a80410d5a191d098f464e66f86\n",
       "1  TH_wiki-0000105191  e77a248506ef16737288fae5759db33a\n",
       "2  TH_wiki-0000105192  2e386f5c3af70f43874618988d4842b2\n",
       "3  TH_wiki-0000105193  2e386f5c3af70f43874618988d4842b2\n",
       "4  TH_wiki-0000105194  2e386f5c3af70f43874618988d4842b2"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "exact_dedup_res = pd.read_parquet(os.path.join(exact_dedup_output_dir,\"_exact_duplicates.parquet\"))\n",
    "print(f\"Number of exact duplicated document:{len(exact_dedup_res)}\")\n",
    "exact_dedup_res.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "fca41870",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>_hashes</th>\n",
       "      <th>id</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0b908a91cdf0544c1ef3015cff4ee07e</td>\n",
       "      <td>TH_wiki-0000157216 TH_wiki-0000066307</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>15f35c239b6579b4642f7656e64576ac</td>\n",
       "      <td>TH_wiki-0000074714 TH_wiki-0000074715 TH_wiki-...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1708cb56ec582f78716f0864dca9382d</td>\n",
       "      <td>TH_wiki-0000021211 TH_wiki-0000021213 TH_wiki-...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>2e386f5c3af70f43874618988d4842b2</td>\n",
       "      <td>TH_wiki-0000105192 TH_wiki-0000105193 TH_wiki-...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>3e6e96a80410d5a191d098f464e66f86</td>\n",
       "      <td>TH_wiki-0000122055 TH_wiki-0000116550</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                            _hashes  \\\n",
       "0  0b908a91cdf0544c1ef3015cff4ee07e   \n",
       "1  15f35c239b6579b4642f7656e64576ac   \n",
       "2  1708cb56ec582f78716f0864dca9382d   \n",
       "3  2e386f5c3af70f43874618988d4842b2   \n",
       "4  3e6e96a80410d5a191d098f464e66f86   \n",
       "\n",
       "                                                  id  \n",
       "0              TH_wiki-0000157216 TH_wiki-0000066307  \n",
       "1  TH_wiki-0000074714 TH_wiki-0000074715 TH_wiki-...  \n",
       "2  TH_wiki-0000021211 TH_wiki-0000021213 TH_wiki-...  \n",
       "3  TH_wiki-0000105192 TH_wiki-0000105193 TH_wiki-...  \n",
       "4              TH_wiki-0000122055 TH_wiki-0000116550  "
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "exact_dedup_res.groupby('_hashes')['id'].agg(lambda x: ' '.join(x)).reset_index().head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "8c9624ac",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'filename': 'thwiki-20240201-pages-articles-multistream.xml.bz2.jsonl', 'id': 'TH_wiki-0000066307', 'language': 'TH', 'source_id': 'thwiki-20240201-thwiki-20240201-pages-articles-multistream.xml.bz2', 'text': '\\n\\nแหล่งข้อมูลอื่น \\n\\nสงขลา\\n \\nรายชื่อเกี่ยวกับจังหวัดสงขลา', 'title': 'รายชื่อโบราณสถานในจังหวัดสงขลา', 'url': 'https://th.wikipedia.org/wiki/%E0%B8%A3%E0%B8%B2%E0%B8%A2%E0%B8%8A%E0%B8%B7%E0%B9%88%E0%B8%AD%E0%B9%82%E0%B8%9A%E0%B8%A3%E0%B8%B2%E0%B8%93%E0%B8%AA%E0%B8%96%E0%B8%B2%E0%B8%99%E0%B9%83%E0%B8%99%E0%B8%88%E0%B8%B1%E0%B8%87%E0%B8%AB%E0%B8%A7%E0%B8%B1%E0%B8%94%E0%B8%AA%E0%B8%87%E0%B8%82%E0%B8%A5%E0%B8%B2'}\n",
      "{'filename': 'thwiki-20240201-pages-articles-multistream.xml.bz2.jsonl', 'id': 'TH_wiki-0000157216', 'language': 'TH', 'source_id': 'thwiki-20240201-thwiki-20240201-pages-articles-multistream.xml.bz2', 'text': '\\n\\nแหล่งข้อมูลอื่น \\n\\nสงขลา\\n \\nรายชื่อเกี่ยวกับจังหวัดสงขลา', 'title': 'รายชื่อโบราณสถานในจังหวัดสงขลา (อำเภอเมืองสงขลาและสิงหนคร)', 'url': 'https://th.wikipedia.org/wiki/%E0%B8%A3%E0%B8%B2%E0%B8%A2%E0%B8%8A%E0%B8%B7%E0%B9%88%E0%B8%AD%E0%B9%82%E0%B8%9A%E0%B8%A3%E0%B8%B2%E0%B8%93%E0%B8%AA%E0%B8%96%E0%B8%B2%E0%B8%99%E0%B9%83%E0%B8%99%E0%B8%88%E0%B8%B1%E0%B8%87%E0%B8%AB%E0%B8%A7%E0%B8%B1%E0%B8%94%E0%B8%AA%E0%B8%87%E0%B8%82%E0%B8%A5%E0%B8%B2%20%28%E0%B8%AD%E0%B8%B3%E0%B9%80%E0%B8%A0%E0%B8%AD%E0%B9%80%E0%B8%A1%E0%B8%B7%E0%B8%AD%E0%B8%87%E0%B8%AA%E0%B8%87%E0%B8%82%E0%B8%A5%E0%B8%B2%E0%B9%81%E0%B8%A5%E0%B8%B0%E0%B8%AA%E0%B8%B4%E0%B8%87%E0%B8%AB%E0%B8%99%E0%B8%84%E0%B8%A3%29'}\n"
     ]
    }
   ],
   "source": [
    "target_list = ['TH_wiki-0000157216', 'TH_wiki-0000066307']\n",
    "for line in extract_lines_with_id(os.path.join(exact_dedup_input_dataset_dir,'thwiki-20240201-pages-articles-multistream.xml.bz2.jsonl'),target_list):\n",
    "    print(line)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4013203c",
   "metadata": {},
   "source": [
    "**[Optional]** You can close the Dask cluster here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "5ef2f05e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# client.cluster.close()\n",
    "# client.shutdown()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7a2feadc",
   "metadata": {},
   "source": [
    "## 5. Fuzzy Deduplication\n",
    "Fuzzy deduplication involves 5 intermediate steps to generate duplicates. Refer to https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/gpudeduplication.html for details.\n",
    "\n",
    "Fuzzy deduplication in this example is a GPU implementation of the MinHash LSH algorithm. The algorithm measures similarity from the statistical overlap of the text, not from its semantic meaning. A few concepts need to be introduced before heading into fuzzy deduplication.\n",
    "1. Jaccard similarity: Jaccard similarity is often used as a metric for the similarity between two sets. It is calculated by dividing the number of common elements in the two sets (intersection) by the total number of unique elements across the two sets (union). For text documents, we transform each document into a set of n-grams. If two documents share a large number of n-grams, they are most likely similar.\n",
    "\n",
    "    ![alt text](./image/jaccard.png)\n",
    "\n",
    "2. Complexity of the problem: To find all similar document pairs in a dataset, we need to compute the pairwise Jaccard similarity across the whole dataset, which makes the complexity $O(N^2)$.\n",
    "\n",
    "The MinHash LSH algorithm is a technique for quickly estimating the similarity between sets, such as documents represented as sets of shingles (n-grams). It finds the pairs with high Jaccard similarity in the corpus in a much more computationally efficient way. At a high level, the algorithm has the following steps:\n",
    "1. Compute the minhash signature of each document.\n",
    "2. Run Locality Sensitive Hashing (LSH) on the minhash signatures to assign buckets to each document. Each document is assigned to multiple buckets. Documents within the same bucket are deemed to be similar.\n",
    "3. Compute the pairwise Jaccard similarity within each bucket to remove the false positives within the bucket.\n",
    "4. Based on the Jaccard similarity, transform the similarity matrix into a graph and run a connected components algorithm. Each connected component in the graph is a final group of similar documents, and the IDs within each group are output for duplicate removal.\n",
    "\n",
    "For a more detailed explanation, refer to https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/cpudeduplication.html.\n",
    "\n",
    "The implementation of MinHash LSH on GPU has 5 steps:\n",
    "1. Minhash computation\n",
    "2. Bucket computation\n",
    "3. Jaccard shuffle for load balancing in a distributed system\n",
    "4. Jaccard similarity computation\n",
    "5. Connected components\n",
    "\n",
    "In this section, we first walk through each sub-step so that users can better understand what is going on under the hood. The last subsection provides an example of the fuzzy deduplication wrapper."
   ]
  },
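  {
   "cell_type": "markdown",
   "id": "3c1a9f01",
   "metadata": {},
   "source": [
    "To make the Jaccard similarity concrete, here is a minimal pure-Python sketch that computes it on character n-grams. The helper names `char_ngrams` and `jaccard` and the two toy strings are assumptions made for this illustration; they are not part of NeMo Curator.\n",
    "\n",
    "```python\n",
    "def char_ngrams(text, n):\n",
    "    # Set of all overlapping character n-grams in the text.\n",
    "    return {text[i:i + n] for i in range(len(text) - n + 1)}\n",
    "\n",
    "\n",
    "def jaccard(a, b):\n",
    "    # Size of the intersection divided by the size of the union.\n",
    "    # Two empty sets are treated as identical.\n",
    "    if not a and not b:\n",
    "        return 1.0\n",
    "    return len(a & b) / len(a | b)\n",
    "\n",
    "\n",
    "doc_a = 'abcde'  # 3-grams: abc, bcd, cde\n",
    "doc_b = 'abcdf'  # 3-grams: abc, bcd, cdf\n",
    "sim = jaccard(char_ngrams(doc_a, 3), char_ngrams(doc_b, 3))\n",
    "print(sim)  # 2 shared n-grams out of 4 unique ones -> 0.5\n",
    "```\n",
    "\n",
    "Running this for every pair of documents is the $O(N^2)$ computation that MinHash LSH is designed to avoid."
   ]
  },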
  {
   "cell_type": "markdown",
   "id": "ffca14ad",
   "metadata": {},
   "source": [
    "**If there is no running Dask cluster, start a GPU Dask cluster here**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e00ba2fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# client = get_client(cluster_type = 'gpu', set_torch_to_use_rmm=False)\n",
    "# print(f\"Number of dask worker:{get_num_workers(client)}\")\n",
    "# client.run(pre_imports)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5df73743",
   "metadata": {},
   "source": [
    "### 5.1 Minhash\n",
    "\n",
    "Run `MinHash()` for this section. The output of `MinHash()` is a parquet file that contains the document ID and the minhash signature, which is an array of 260 32-bit integers. To obtain the signature, we go through the following steps:\n",
    "1. Generate the set of n-grams of a document. For example, for doc = `Nemo Curator is a data curation tool`, the set of 3-grams is `['Nemo Curator is','Curator is a','is a data','a data curation','data curation tool']`. Word-level n-grams are used here for readability; `MinHash()` works on character n-grams (see `char_ngrams` below).\n",
    "2. Hash each n-gram into a numerical value.\n",
    "3. Generate a random hash function $H_1()$ that hashes each numeric n-gram into a 32-bit integer, and take the minimum integer as the minhash value for $H_1()$.\n",
    "4. Repeat steps 2 and 3 with hash functions $H_x()$ until the desired minhash length is reached. The minhash values of all iterations are appended together to form the final minhash array.\n",
    "\n",
    "Arguments include:\n",
    "- `seed`: Random seed used to initialize the hash functions that compute the minhashes. Keep this value the same across experiments for reproducibility\n",
    "- `num_hashes`: Length of each minhash array. Default is 260. A longer minhash array gives a better estimate of the actual Jaccard similarity, but requires more computation\n",
    "- `char_ngrams`: n-gram length, in characters\n",
    "- `use_64bit_hash`: Whether to use a 64-bit or a 32-bit hash function\n",
    "- `id_field`: Key in the input file that identifies the document ID\n",
    "- `text_field`: Key in the input file that contains the document text\n",
    "- `cache_dir`: If specified, the intermediate result is written to the `cache_dir`\n",
    "\n"
   ]
  },
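  {
   "cell_type": "markdown",
   "id": "3c1a9f02",
   "metadata": {},
   "source": [
    "The steps above can be sketched on CPU in plain Python. This is a minimal illustration, not the `MinHash()` implementation: the MD5-based base hash, the `(a*x + b) mod p` hash family, and the helper names `char_ngrams`, `minhash_signature` and `estimate_jaccard` are all assumptions made for this sketch.\n",
    "\n",
    "```python\n",
    "import hashlib\n",
    "import random\n",
    "\n",
    "\n",
    "def char_ngrams(text, n):\n",
    "    # Step 1: set of overlapping character n-grams.\n",
    "    return {text[i:i + n] for i in range(len(text) - n + 1)}\n",
    "\n",
    "\n",
    "def minhash_signature(text, num_hashes=260, n=5, seed=10):\n",
    "    grams = char_ngrams(text, n)\n",
    "    if not grams:\n",
    "        raise ValueError('text is shorter than the n-gram length')\n",
    "    # Step 2: hash each n-gram to a 32-bit integer (truncated MD5, an arbitrary choice).\n",
    "    base = [int.from_bytes(hashlib.md5(g.encode('utf-8')).digest()[:4], 'little')\n",
    "            for g in grams]\n",
    "    # Steps 3-4: one seeded random hash function per signature slot; keep the minimum.\n",
    "    rng = random.Random(seed)\n",
    "    prime = (1 << 61) - 1\n",
    "    signature = []\n",
    "    for _ in range(num_hashes):\n",
    "        a = rng.randrange(1, prime)\n",
    "        b = rng.randrange(0, prime)\n",
    "        signature.append(min(((a * x + b) % prime) & 0xFFFFFFFF for x in base))\n",
    "    return signature\n",
    "\n",
    "\n",
    "def estimate_jaccard(sig_a, sig_b):\n",
    "    # The fraction of matching slots estimates the Jaccard similarity.\n",
    "    return sum(x == y for x, y in zip(sig_a, sig_b)) / len(sig_a)\n",
    "\n",
    "\n",
    "sig_a = minhash_signature('Nemo Curator is a data curation tool')\n",
    "sig_b = minhash_signature('Nemo Curator is a data curation toolkit')\n",
    "print(len(sig_a), estimate_jaccard(sig_a, sig_b))\n",
    "```\n",
    "\n",
    "Because both signatures are built with the same seed, their slots are comparable. Signatures computed with different seeds cannot be compared, which is one reason to keep `seed` fixed."
   ]
  },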
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "1fc5bff3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from nemo_curator import MinHash"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7bf9cc8d",
   "metadata": {},
   "source": [
    "Define parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "d600d1b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Input\n",
    "minhash_data_path = added_id_output_path\n",
    "#Output\n",
    "minshah_base_output_path = os.path.join(data_dir,\"fuzzy/minhash\")\n",
    "minshah_log_dir = os.path.join(minshah_base_output_path,'log')\n",
    "minshah_output_dir = os.path.join(minshah_base_output_path,'data')\n",
    "#Specify dataset name\n",
    "dataset_name = 'TH_wikipedia'\n",
    "\n",
    "#Relevant parameters\n",
    "minhash_id_field = 'id'\n",
    "minhash_text_field = 'text'\n",
    "seed = 10\n",
    "minhash_length = 260\n",
    "char_ngram = 5\n",
    "use_64bit_hash = False\n",
    "files_per_partition = 2\n",
    "\n",
    "!mkdir -p {minshah_log_dir}\n",
    "!mkdir -p {minshah_output_dir}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1c31ddf4",
   "metadata": {},
   "source": [
    "Run MinHash"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "88540950",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Computing minhashes for /work_dir/tutorials/single_node_tutorial/workspace/add_id/cleaned\n",
      "Reading 1 files\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.10/dist-packages/nemo_curator/modules/fuzzy_dedup.py:175: UserWarning: Output path /work_dir/tutorials/single_node_tutorial/workspace/fuzzy/minhash/data/_minhashes.parquet already exists and will be overwritten\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Time taken for MinHash:6.340771198272705\n"
     ]
    }
   ],
   "source": [
    "t0 = time.time()\n",
    "print(f\"Computing minhashes for {minhash_data_path}\")\n",
    "\n",
    "# Load data. Only the [minhash_id_field, text_field] columns are needed\n",
    "files = get_all_files_paths_under(root=minhash_data_path, recurse_subdirectories=False)\n",
    "files = [f for f in files if f.endswith(\".jsonl\")]\n",
    "df = read_data(\n",
    "    files,\n",
    "    file_type=\"jsonl\",\n",
    "    backend=\"cudf\",\n",
    "    files_per_partition=files_per_partition,\n",
    "    add_filename=False,\n",
    ")[[minhash_id_field, minhash_text_field]]\n",
    "\n",
    "# Run MinHash() on input data\n",
    "minhasher = MinHash(\n",
    "    seed=seed,\n",
    "    num_hashes=minhash_length,\n",
    "    char_ngrams=char_ngram,\n",
    "    use_64bit_hash=use_64bit_hash,\n",
    "    logger=minshah_log_dir,\n",
    "    id_field=minhash_id_field,\n",
    "    text_field=minhash_text_field,\n",
    "    cache_dir=minshah_output_dir\n",
    ")\n",
    "res = minhasher(DocumentDataset(df)).df\n",
    "\n",
    "print(f\"Time taken for MinHash:{time.time()-t0}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "158bf3ab",
   "metadata": {},
   "source": [
    "Verify result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "10b5eb55",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>id</th>\n",
       "      <th>_minhash_signature</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>TH_wiki-0000000000</td>\n",
       "      <td>[11565725, 19782487, 9831980, 5480992, 2306475...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>TH_wiki-0000000001</td>\n",
       "      <td>[407876, 107572, 824528, 346831, 216554, 10963...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>TH_wiki-0000000002</td>\n",
       "      <td>[727721, 694551, 233868, 346831, 216554, 77001...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>TH_wiki-0000000003</td>\n",
       "      <td>[1149282, 931656, 2515604, 1428622, 4964646, 4...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>TH_wiki-0000000004</td>\n",
       "      <td>[1559901, 11771639, 487706, 826569, 1203860, 5...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                   id                                 _minhash_signature\n",
       "0  TH_wiki-0000000000  [11565725, 19782487, 9831980, 5480992, 2306475...\n",
       "1  TH_wiki-0000000001  [407876, 107572, 824528, 346831, 216554, 10963...\n",
       "2  TH_wiki-0000000002  [727721, 694551, 233868, 346831, 216554, 77001...\n",
       "3  TH_wiki-0000000003  [1149282, 931656, 2515604, 1428622, 4964646, 4...\n",
       "4  TH_wiki-0000000004  [1559901, 11771639, 487706, 826569, 1203860, 5..."
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "minhash_res = pd.read_parquet(os.path.join(minshah_output_dir, \"_minhashes.parquet\"))\n",
    "minhash_res.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0bce0f80",
   "metadata": {},
   "source": [
    "### 5.2 LSH\n",
    "`LSH()` implements the LSH algorithm, which includes the following steps:\n",
    "1. Divide the minhash array into `X` portions.\n",
    "2. For each portion, hash the minhash values into a bucket. Each document is therefore assigned to `X` buckets.\n",
    "3. Documents within the same bucket are deemed similar. Since every document is assigned to `X` buckets and two documents are deemed similar as soon as they share one or more buckets, LSH produces more false positives than false negatives. The false positives are filtered out in the following modules, namely the Jaccard computation.\n",
    "\n",
    "Arguments include:\n",
    "- `num_hashes`: Length of the minhash signature. Must be consistent with `MinHash()`\n",
    "- `num_buckets`: Number of buckets\n",
    "- `buckets_per_shuffle`: Number of buckets to shuffle concurrently\n",
    "- `id_fields`: Columns in the input that identify the document ID. In this example, they are `[\"dataset_id\", \"doc_id\"]`\n",
    "- `minhash_field`: Key in the input file that contains the document MinHash signature\n",
    "- `cache_dir`: If specified, the intermediate result is written to the `cache_dir`\n",
    "\n"
   ]
  },
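  {
   "cell_type": "markdown",
   "id": "3c1a9f03",
   "metadata": {},
   "source": [
    "The bucketing idea can be sketched in plain Python as follows. This is an illustration under simplifying assumptions, not the `LSH()` implementation: it assumes the signature is split into equally sized portions, and the helper names `lsh_buckets` and `candidate_probability` as well as the toy signatures are made up. Under that assumption, the notebook's settings (`num_hashes=260`, `num_buckets=20`) would give 13 values per portion.\n",
    "\n",
    "```python\n",
    "from collections import defaultdict\n",
    "\n",
    "\n",
    "def lsh_buckets(signatures, num_bands):\n",
    "    # signatures: dict of doc id -> minhash signature (lists of equal length).\n",
    "    buckets = defaultdict(set)\n",
    "    for doc_id, sig in signatures.items():\n",
    "        rows = len(sig) // num_bands\n",
    "        for band in range(num_bands):\n",
    "            chunk = tuple(sig[band * rows:(band + 1) * rows])\n",
    "            # The bucket key is the band index plus the band content.\n",
    "            buckets[(band, chunk)].add(doc_id)\n",
    "    # Only buckets holding two or more documents yield candidate pairs.\n",
    "    return [sorted(ids) for ids in buckets.values() if len(ids) > 1]\n",
    "\n",
    "\n",
    "def candidate_probability(s, num_bands, rows):\n",
    "    # Probability that two documents with Jaccard similarity s share at least one bucket.\n",
    "    return 1 - (1 - s ** rows) ** num_bands\n",
    "\n",
    "\n",
    "toy = {\n",
    "    'a': [1, 2, 3, 4, 5, 6],\n",
    "    'b': [1, 2, 9, 9, 5, 6],\n",
    "    'c': [7, 7, 7, 7, 7, 7],\n",
    "}\n",
    "print(lsh_buckets(toy, num_bands=3))  # 'a' and 'b' collide in the first and last band\n",
    "print(candidate_probability(0.8, num_bands=20, rows=13))\n",
    "```\n",
    "\n",
    "The second function shows the trade-off behind the parameters: more portions make a collision more likely (more false positives), while more values per portion make it less likely (more false negatives)."
   ]
  },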
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "645b8a53",
   "metadata": {},
   "outputs": [],
   "source": [
    "from nemo_curator import LSH\n",
    "from nemo_curator.utils.fuzzy_dedup_utils.id_mapping import \\\n",
    "    convert_str_id_to_int"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "110db216",
   "metadata": {},
   "source": [
    "Define parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "738ab265",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Input\n",
    "lsh_input_data_path = minshah_output_dir\n",
    "\n",
    "#Output\n",
    "lsh_base_output_path = os.path.join(data_dir,\"fuzzy/lsh\")\n",
    "lsh_log_dir = os.path.join(lsh_base_output_path,'log')\n",
    "lsh_output_dir = os.path.join(lsh_base_output_path,'data')\n",
    "\n",
    "#Relevant parameters\n",
    "lsh_id_field = 'id'\n",
    "minhash_field = '_minhash_signature'\n",
    "minhash_length=260\n",
    "num_bands=20\n",
    "buckets_per_shuffle=1\n",
    "\n",
    "!mkdir -p {lsh_log_dir}\n",
    "!mkdir -p {lsh_output_dir}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a5250a2a",
   "metadata": {},
   "source": [
    "Run LSH"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "1ef61e2b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.10/dist-packages/nemo_curator/modules/fuzzy_dedup.py:361: UserWarning: Output path /work_dir/tutorials/single_node_tutorial/workspace/fuzzy/lsh/data/_buckets.parquet already exists and will be overwritten\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Time taken for LSH:19.37230634689331\n"
     ]
    }
   ],
   "source": [
    "t0 = time.time()\n",
    "\n",
    "#Load MinHash output\n",
    "df = dask_cudf.read_parquet(lsh_input_data_path, blocksize=\"2GB\", aggregate_files=True, backend = \"cudf\")\n",
    "df = df.map_partitions(\n",
    "    convert_str_id_to_int,\n",
    "    id_column=lsh_id_field,\n",
    "    meta=cudf.DataFrame(\n",
    "        {minhash_field: [[1, 2, 3]], \"doc_id\": [1], \"dataset_id\": np.uint32(1)}\n",
    "    ),\n",
    ")\n",
    "\n",
    "#Run LSH()\n",
    "lsh = LSH(\n",
    "    cache_dir=lsh_output_dir,\n",
    "    num_hashes=minhash_length,\n",
    "    num_buckets=num_bands,\n",
    "    buckets_per_shuffle=buckets_per_shuffle,\n",
    "    id_fields=[\"dataset_id\", \"doc_id\"],\n",
    "    minhash_field=minhash_field,\n",
    "    logger=lsh_log_dir,\n",
    ")\n",
    "res = lsh(DocumentDataset(df))\n",
    "\n",
    "t1 = time.time()\n",
    "print(f\"Time taken for LSH:{time.time()-t0}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ad2e3b60",
   "metadata": {},
   "source": [
    "Verify result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "9d0449c6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>dataset_id</th>\n",
       "      <th>doc_id</th>\n",
       "      <th>_bucket_id</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1692361878</td>\n",
       "      <td>123547</td>\n",
       "      <td>210</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1692361878</td>\n",
       "      <td>93844</td>\n",
       "      <td>120</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1692361878</td>\n",
       "      <td>66564</td>\n",
       "      <td>86</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1692361878</td>\n",
       "      <td>93845</td>\n",
       "      <td>120</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1692361878</td>\n",
       "      <td>66565</td>\n",
       "      <td>86</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   dataset_id  doc_id  _bucket_id\n",
       "0  1692361878  123547         210\n",
       "1  1692361878   93844         120\n",
       "2  1692361878   66564          86\n",
       "3  1692361878   93845         120\n",
       "4  1692361878   66565          86"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lsh_res = pd.read_parquet(os.path.join(lsh_output_dir, \"_buckets.parquet\"))\n",
    "lsh_res.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f952f074",
   "metadata": {},
   "source": [
    "### 5.3 Jaccard Shuffle\n",
    "In this section, we will be using `_MapBuckets()` and `_Shuffle()`.\n",
    "\n",
    "`_MapBuckets()` takes the input text data in jsonl format and the bucket information output by LSH, and maps the documents to their respective buckets. The resulting DataFrame, which contains the anchor documents and their associated bucket information, is written to a parquet file. Arguments include:\n",
    "- `id_field`: Key in input .jsonl file for identifying document ID\n",
    "- `text_field`: Key in input .jsonl file which contains document text.\n",
    "- `bucket_field`: Key in input _buckets.parquet which contains `bucket_id`.\n",
    "- `num_anchors`: Number of anchors (documents in the same bucket) to be output\n",
    "\n",
    "\n",
    "`_Shuffle()` performs a shuffling operation on the documents based on their bucket assignments and outputs the result in .parquet format. This shuffling operation is a crucial step in the deduplication process, as it helps distribute similar documents across different partitions or workers, enabling efficient parallel processing and deduplication in subsequent steps. Arguments include:\n",
    "- `id_fields`: Columns in `_buckets.parquet` that map to the original `id` in the .jsonl data file. In this example, they are `[\"dataset_id\", \"doc_id\"]`\n",
    "- `text_field`: Key in input .jsonl file which contains document text.\n",
    "- `int_to_str_id`: Key in input .jsonl file for identifying document ID\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "707ea54d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from nemo_curator.utils.fuzzy_dedup_utils.io_utils import (\n",
    "    get_bucket_ddf_from_parquet_path,\n",
    "    get_text_ddf_from_json_path_with_blocksize,\n",
    ")\n",
    "from nemo_curator.modules.fuzzy_dedup import _MapBuckets,_Shuffle"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8f2e321d",
   "metadata": {},
   "source": [
    "Define parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "70e2dff9",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Input\n",
    "input_data_paths = [minhash_data_path]\n",
    "input_bucket_path = lsh_output_dir\n",
    "\n",
    "#Output\n",
    "jaccard_shuffle_base_output_path = os.path.join(data_dir,\"fuzzy/jaccard_shuffle\")\n",
    "output_anchor_docs_with_bk_path = os.path.join(jaccard_shuffle_base_output_path, \"anchor_docs_with_bk.parquet\")\n",
    "input_anchor_docs_with_bk_dir = output_anchor_docs_with_bk_path\n",
    "jaccard_shuffle_log_path = os.path.join(jaccard_shuffle_base_output_path,\"log\")\n",
    "output_shuffled_docs_path = os.path.join(jaccard_shuffle_base_output_path, \"shuffled_docs.parquet\")\n",
    "\n",
    "#Relevant parameters for _MapBucket()\n",
    "text_ddf_blocksize = 256\n",
    "bucket_mapping_ddf_blocksize = 256\n",
    "num_files = None\n",
    "shuffle_type ='tasks'\n",
    "input_bucket_field = '_bucket_id'\n",
    "input_id_field = 'id'\n",
    "input_text_field = 'text'\n",
    "\n",
    "#Relevant parameters for _Shuffle()\n",
    "shuffle_id_fields=[\"dataset_id\", \"doc_id\"]\n",
    "int_to_str_id='id'\n",
    "\n",
    "!mkdir -p {jaccard_shuffle_base_output_path}\n",
    "!mkdir -p {jaccard_shuffle_log_path}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d0f19efa",
   "metadata": {},
   "source": [
    "Run Jaccard map bucket"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "b2850b0a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of files being read for jaccard calculation = 1\n",
      "Number of ddf_bk partitions = 1\n",
      "Time taken for Bucket Mapping:1.239295244216919 s\n"
     ]
    }
   ],
   "source": [
    "t0 = time.time()\n",
    "num_workers = get_num_workers(client)\n",
    "\n",
    "# Read .jsonl input data\n",
    "ddf_text = get_text_ddf_from_json_path_with_blocksize(\n",
    "    input_data_paths=input_data_paths,\n",
    "    num_files=num_files,\n",
    "    blocksize=text_ddf_blocksize,\n",
    "    id_column=input_id_field,\n",
    "    text_column=input_text_field,\n",
    ")\n",
    "# Read \"_buckets.parquet\"\n",
    "ddf_bk = get_bucket_ddf_from_parquet_path(input_bucket_path=input_bucket_path, num_workers=num_workers)\n",
    "\n",
    "#Run _MapBuckets()\n",
    "map_buckets = _MapBuckets(id_fields=shuffle_id_fields, bucket_field=input_bucket_field, logger=jaccard_shuffle_log_path)\n",
    "ddf_anchor_docs_with_bk = map_buckets.map_buckets_with_anchors(documents_df=ddf_text, buckets_df=ddf_bk, shuffle_type=shuffle_type)\n",
    "\n",
    "#Write to disk\n",
    "ddf_anchor_docs_with_bk.to_parquet(output_anchor_docs_with_bk_path, write_index=False)\n",
    "\n",
    "print(f\"Time taken for Bucket Mapping:{time.time()-t0} s\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1533a15",
   "metadata": {},
   "source": [
    "Verify result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "d74012c3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>dataset_id</th>\n",
       "      <th>doc_id</th>\n",
       "      <th>anchor_1_dataset_id</th>\n",
       "      <th>anchor_1_doc_id</th>\n",
       "      <th>anchor_0_dataset_id</th>\n",
       "      <th>anchor_0_doc_id</th>\n",
       "      <th>_output_partition_id</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1692361878</td>\n",
       "      <td>127258</td>\n",
       "      <td>1692361878</td>\n",
       "      <td>127781</td>\n",
       "      <td>1692361878</td>\n",
       "      <td>126955</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1692361878</td>\n",
       "      <td>85383</td>\n",
       "      <td>1692361878</td>\n",
       "      <td>85364</td>\n",
       "      <td>1692361878</td>\n",
       "      <td>85374</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1692361878</td>\n",
       "      <td>45030</td>\n",
       "      <td>1692361878</td>\n",
       "      <td>85200</td>\n",
       "      <td>1692361878</td>\n",
       "      <td>45030</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1692361878</td>\n",
       "      <td>127259</td>\n",
       "      <td>1692361878</td>\n",
       "      <td>127781</td>\n",
       "      <td>1692361878</td>\n",
       "      <td>126955</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1692361878</td>\n",
       "      <td>127968</td>\n",
       "      <td>1692361878</td>\n",
       "      <td>127961</td>\n",
       "      <td>1692361878</td>\n",
       "      <td>127996</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   dataset_id  doc_id  anchor_1_dataset_id  anchor_1_doc_id  \\\n",
       "0  1692361878  127258           1692361878           127781   \n",
       "1  1692361878   85383           1692361878            85364   \n",
       "2  1692361878   45030           1692361878            85200   \n",
       "3  1692361878  127259           1692361878           127781   \n",
       "4  1692361878  127968           1692361878           127961   \n",
       "\n",
       "   anchor_0_dataset_id  anchor_0_doc_id  _output_partition_id  \n",
       "0           1692361878           126955                     0  \n",
       "1           1692361878            85374                     0  \n",
       "2           1692361878            45030                     0  \n",
       "3           1692361878           126955                     0  \n",
       "4           1692361878           127996                     0  "
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "map_bucket_res = pd.read_parquet(output_anchor_docs_with_bk_path)\n",
    "map_bucket_res.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1487b1ad",
   "metadata": {},
   "source": [
    "**[Optional]** Remove previous Jaccard Shuffle results. Run only when there are files under the Jaccard Shuffle output path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "b414f703",
   "metadata": {},
   "outputs": [],
   "source": [
    "#!rm -r {output_shuffled_docs_path}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f33a6782",
   "metadata": {},
   "source": [
    "Run Jaccard Shuffle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "86d1b3e5",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|                                                                                                                                                                                                                                                                                | 0/1 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Started processing bucket-map partitions 0 through 1 of 1\n",
      "Using 1 text partitions.\n",
      "Starting text bytes aware shuffle\n",
      "Will write 30596 rows to disk\n",
      "Text-df partition  1/1 completed in 2.4342942237854004\n",
      "Bucket partition  1/1 completed in 2.4410006999969482\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00,  2.45s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Time taken for Jaccard Shuffle = 2.4802186489105225 s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "t0 = time.time()\n",
    "\n",
    "#Run _Shuffle() on results of _MapBucket()\n",
    "shuffle = _Shuffle(\n",
    "    id_fields=shuffle_id_fields,\n",
    "    text_field=input_text_field,\n",
    "    int_to_str_id=int_to_str_id,\n",
    "    logger=jaccard_shuffle_log_path\n",
    ")\n",
    "shuffle.shuffle_docs_on_buckets(\n",
    "    documents_df=ddf_text,\n",
    "    bucket_w_anchors_path=input_anchor_docs_with_bk_dir,\n",
    "    output_shuffled_docs_path=output_shuffled_docs_path,\n",
    "    bucket_mapping_df_blocksize=bucket_mapping_ddf_blocksize,\n",
    "#     parts_per_worker=1,\n",
    "#     bucket_parts_per_worker=8,\n",
    "    partition_on=\"_output_partition_id\",\n",
    ")\n",
    "\n",
    "print(f\"Time taken for Jaccard Shuffle = {time.time()-t0} s\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "86b06cb5",
   "metadata": {},
   "source": [
    "Verify result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "1b51a5fb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>text</th>\n",
       "      <th>_text_bytes</th>\n",
       "      <th>id</th>\n",
       "      <th>anchor_0_id</th>\n",
       "      <th>anchor_1_id</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>การแข่งขันกีฬากรีฑาในโอลิมปิกฤดูร้อน 2020 – เด...</td>\n",
       "      <td>1457</td>\n",
       "      <td>1692361878-135417</td>\n",
       "      <td>1692361878-135463</td>\n",
       "      <td>1692361878-135417</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>การแข่งขันกีฬากรีฑาในโอลิมปิกฤดูร้อน 2020 – เด...</td>\n",
       "      <td>1457</td>\n",
       "      <td>1692361878-135417</td>\n",
       "      <td>1692361878-135392</td>\n",
       "      <td>1692361878-135447</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>สุริยุปราคาบางส่วนจะเกิดขึ้นในวันที่ 13 กรกฎาค...</td>\n",
       "      <td>1262</td>\n",
       "      <td>1692361878-83363</td>\n",
       "      <td>1692361878-94231</td>\n",
       "      <td>1692361878-83363</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>สุริยุปราคาบางส่วนจะเกิดขึ้นในวันที่ 13 กรกฎาค...</td>\n",
       "      <td>1262</td>\n",
       "      <td>1692361878-83363</td>\n",
       "      <td>1692361878-94905</td>\n",
       "      <td>1692361878-83363</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>สุริยุปราคาบางส่วนจะเกิดขึ้นในวันที่ 13 กรกฎาค...</td>\n",
       "      <td>1262</td>\n",
       "      <td>1692361878-83363</td>\n",
       "      <td>1692361878-94906</td>\n",
       "      <td>1692361878-94905</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                text  _text_bytes  \\\n",
       "0  การแข่งขันกีฬากรีฑาในโอลิมปิกฤดูร้อน 2020 – เด...         1457   \n",
       "1  การแข่งขันกีฬากรีฑาในโอลิมปิกฤดูร้อน 2020 – เด...         1457   \n",
       "2  สุริยุปราคาบางส่วนจะเกิดขึ้นในวันที่ 13 กรกฎาค...         1262   \n",
       "3  สุริยุปราคาบางส่วนจะเกิดขึ้นในวันที่ 13 กรกฎาค...         1262   \n",
       "4  สุริยุปราคาบางส่วนจะเกิดขึ้นในวันที่ 13 กรกฎาค...         1262   \n",
       "\n",
       "                  id        anchor_0_id        anchor_1_id  \n",
       "0  1692361878-135417  1692361878-135463  1692361878-135417  \n",
       "1  1692361878-135417  1692361878-135392  1692361878-135447  \n",
       "2   1692361878-83363   1692361878-94231   1692361878-83363  \n",
       "3   1692361878-83363   1692361878-94905   1692361878-83363  \n",
       "4   1692361878-83363   1692361878-94906   1692361878-94905  "
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "jaccard_shuffle_res = pd.read_parquet(os.path.join(output_shuffled_docs_path,\"_output_partition_id=0/batch_1_1.parquet\"))\n",
    "jaccard_shuffle_res.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b8644e51",
   "metadata": {},
   "source": [
    "### 5.4 Jaccard Compute\n",
    "We will be using `JaccardSimilarity()`.This is to computes the Jaccard similarity between document pairs. Result is a parquet dataset consisting of document id pair along with their Jaccard similarity score. To compute Jaccard similarity between two documents, we first convert the document into sets of n-grams and then compute the Jaccard similarity of the two sets.\n",
    "\n",
    "Arguments include:\n",
    "- `id_field`: Column in input .parquet file identifying document ID\n",
    "- `text_field`: Column in input .parquet file identifying document text\n",
    "- `anchor_id_fields`: Column in input .parquet file identifying anchors. This can be generated by specifying number of anchor used in `_MapBucket` whose default value is 2\n",
    "- `ngram_width`: n-gram used"
   ]
  },
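  {
   "cell_type": "markdown",
   "id": "a7d3f0c1",
   "metadata": {},
   "source": [
    "To make the n-gram description above concrete, here is a minimal pure-Python sketch of Jaccard similarity on two strings. It is not the `JaccardSimilarity` implementation, which runs on GPU dataframes; the helper names `char_ngrams` and `jaccard_similarity` are hypothetical, and the use of character-level n-grams is an assumption made for illustration.\n",
    "\n",
    "```python\n",
    "def char_ngrams(text, n):\n",
    "    # Set of all contiguous character n-grams (assumed tokenization).\n",
    "    # Texts no longer than n yield a single gram: the text itself.\n",
    "    if len(text) <= n:\n",
    "        return {text}\n",
    "    return {text[i:i + n] for i in range(len(text) - n + 1)}\n",
    "\n",
    "\n",
    "def jaccard_similarity(a, b, n=5):\n",
    "    # Size of the intersection divided by size of the union of the two sets.\n",
    "    grams_a, grams_b = char_ngrams(a, n), char_ngrams(b, n)\n",
    "    union = grams_a | grams_b\n",
    "    return len(grams_a & grams_b) / len(union) if union else 0.0\n",
    "\n",
    "\n",
    "print(jaccard_similarity('abcdefgh', 'abcdefgx', n=5))\n",
    "```\n",
    "\n",
    "With `n=5`, the two strings share 3 of their 5 distinct 5-grams, so the call prints `0.6`."
   ]
  },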
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "b1a532a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "from nemo_curator.modules.fuzzy_dedup import JaccardSimilarity"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c9e65975",
   "metadata": {},
   "source": [
    "Define parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "291d3aaa",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Input\n",
    "shuffled_docs_path = output_shuffled_docs_path\n",
    "\n",
    "#Output\n",
    "jaccard_compute_base_output_path = os.path.join(data_dir,\"fuzzy/jaccard_compute\")\n",
    "jaccard_compute_output_results_path = os.path.join(jaccard_compute_base_output_path, \"jaccard_similarity_results.parquet\")\n",
    "\n",
    "#Relevant parameters\n",
    "input_id_field = 'id'\n",
    "input_text_field = 'text'\n",
    "ngram_size = 5\n",
    "num_anchors = 2\n",
    "\n",
    "!mkdir -p {jaccard_compute_base_output_path}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9341b58c",
   "metadata": {},
   "source": [
    "Run Jaccard Compute"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "9b1b9bdd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running jaccard compute script\n",
      "Time taken for Jaccard Computing: 0.735356330871582\n"
     ]
    }
   ],
   "source": [
    "# enable_spilling()\n",
    "# client.run(enable_spilling)\n",
    "\n",
    "print(\"Running jaccard compute script\", flush=True)\n",
    "t0 = time.time()\n",
    "\n",
    "jaccard = JaccardSimilarity(\n",
    "    id_field=input_id_field,\n",
    "    text_field=input_text_field,\n",
    "    anchor_id_fields=[f\"anchor_{i}_{input_id_field}\" for i in range(num_anchors)],\n",
    "    ngram_width=ngram_size,\n",
    ")\n",
    "\n",
    "#Load and run Jaccard compute\n",
    "result_df = jaccard.jaccard_compute(shuffled_docs_path)\n",
    "\n",
    "result_df.to_parquet(jaccard_compute_output_results_path, write_index=False, write_metadata_file=False)\n",
    "\n",
    "print(f\"Time taken for Jaccard Computing: {time.time()-t0}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bb740d30",
   "metadata": {},
   "source": [
    "Verify output. You might see that there are repeated `id_x` and `id_y` pairs. This is expected as a pair of similar documents is likely to share numerous same buckets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "a41d1f09",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>id_x</th>\n",
       "      <th>id_y</th>\n",
       "      <th>jaccard</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1692361878-136568</td>\n",
       "      <td>1692361878-136566</td>\n",
       "      <td>0.754448</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1692361878-136568</td>\n",
       "      <td>1692361878-136566</td>\n",
       "      <td>0.754448</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1692361878-136568</td>\n",
       "      <td>1692361878-136566</td>\n",
       "      <td>0.754448</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1692361878-136568</td>\n",
       "      <td>1692361878-136566</td>\n",
       "      <td>0.754448</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1692361878-92875</td>\n",
       "      <td>1692361878-87743</td>\n",
       "      <td>0.828794</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                id_x               id_y   jaccard\n",
       "0  1692361878-136568  1692361878-136566  0.754448\n",
       "1  1692361878-136568  1692361878-136566  0.754448\n",
       "2  1692361878-136568  1692361878-136566  0.754448\n",
       "3  1692361878-136568  1692361878-136566  0.754448\n",
       "4   1692361878-92875   1692361878-87743  0.828794"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "jaccard_compute_res = pd.read_parquet(jaccard_compute_output_results_path)\n",
    "jaccard_compute_res.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a505402e",
   "metadata": {},
   "source": [
    "### 5.5 Connected Components\n",
    "This section uses `ConnectedComponents()`.This section takes a dataset consisting of document pairs and their corresponding jaccard similarity to construct a non-directed graph. A edge will be form between documents whose Jaccard similarity is higher than the threshold (0.8 in this example). It will then identify the connected components in this graph. Documents within the same connected components are deemed duplicated\n",
    "\n",
    "Arguments include:\n",
    "- `cache_dir`:Output path for intermediate results\n",
    "- `jaccard_pairs_path`:Input path for `jaccard_similarity_results.parquet`\n",
    "- `id_column`:prefix of ID column in `jaccard_similarity_results.parquet`\n",
    "- `jaccard_threshold`:Threshold to determine if an edge exists between two documents"
   ]
  },
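  {
   "cell_type": "markdown",
   "id": "b91e4c72",
   "metadata": {},
   "source": [
    "The idea described above can be sketched in a few lines of pure Python with a union-find structure. This is an illustrative toy, not how `ConnectedComponents` is implemented; the function name `connected_components` and the sample pairs are hypothetical.\n",
    "\n",
    "```python\n",
    "def connected_components(pairs, threshold):\n",
    "    # pairs: iterable of (id_x, id_y, jaccard). Returns {doc_id: group_label}.\n",
    "    parent = {}\n",
    "\n",
    "    def find(x):\n",
    "        parent.setdefault(x, x)\n",
    "        while parent[x] != x:\n",
    "            parent[x] = parent[parent[x]]  # path halving keeps trees shallow\n",
    "            x = parent[x]\n",
    "        return x\n",
    "\n",
    "    for id_x, id_y, score in pairs:\n",
    "        root_x, root_y = find(id_x), find(id_y)  # registers both documents as nodes\n",
    "        if score > threshold and root_x != root_y:\n",
    "            parent[root_y] = root_x  # edge above the threshold: merge the groups\n",
    "    return {doc: find(doc) for doc in list(parent)}\n",
    "\n",
    "\n",
    "pairs = [\n",
    "    ('d1', 'd2', 0.83),\n",
    "    ('d1', 'd2', 0.83),  # a repeated pair is harmless\n",
    "    ('d2', 'd3', 0.91),\n",
    "    ('d4', 'd5', 0.75),  # below the threshold, so no edge\n",
    "]\n",
    "groups = connected_components(pairs, threshold=0.8)\n",
    "num_groups = len(set(groups.values()))\n",
    "print(num_groups, len(groups) - num_groups)\n",
    "```\n",
    "\n",
    "The sketch prints `3 2`: three groups, and two documents that would be dropped if one document per group is kept. The counts logged by the Connected Components run in this notebook follow the same relationship (8544 nodes, 5465 groups, 3079 docs removed)."
   ]
  },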
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "3bff521b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from nemo_curator.modules.fuzzy_dedup import ConnectedComponents"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d8afed6a",
   "metadata": {},
   "source": [
    "Define parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "b40735dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Input\n",
    "jaccard_pairs_path = jaccard_compute_output_results_path\n",
    "\n",
    "#Output\n",
    "connected_component_base_output_path = os.path.join(data_dir,\"fuzzy/cc\")\n",
    "connected_component_output_path = os.path.join(connected_component_base_output_path, \"connected_components.parquet\")\n",
    "connected_component_cache_dir = os.path.join(connected_component_base_output_path, \"cache\")\n",
    "\n",
    "#Relevant parameters\n",
    "input_id_field = 'id'\n",
    "jaccard_threshold = 0.8\n",
    "\n",
    "!mkdir -p {connected_component_base_output_path}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "33d8957f",
   "metadata": {},
   "source": [
    "Run Connected Component"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "fe62dd51",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "batch_id = 0/1, time = 0.29015278816223145\n",
      "# of groups 5465\n",
      "# of docs removed 3079\n",
      "assert num_nodes:8544==labels_df:8544 passed\n",
      "Time taken for Connected Component: 4.489336729049683 s\n"
     ]
    }
   ],
   "source": [
    "t0 = time.time()\n",
    "    \n",
    "components_stage = ConnectedComponents(\n",
    "    cache_dir=connected_component_cache_dir,\n",
    "    jaccard_pairs_path=jaccard_pairs_path,\n",
    "    id_column=input_id_field,\n",
    "    convert_str_ids=True,\n",
    "    jaccard_threshold=jaccard_threshold,\n",
    ")\n",
    "\n",
    "#Load and run connected component\n",
    "components_stage.cc_workflow(output_path=connected_component_output_path)\n",
    "print(f\"Time taken for Connected Component: {time.time()-t0} s\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "669495ee",
   "metadata": {},
   "source": [
    "Verify the result of `Connected Components`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "efbd6973",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>dataset_id</th>\n",
       "      <th>doc_id</th>\n",
       "      <th>group</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1692361878</td>\n",
       "      <td>122282</td>\n",
       "      <td>903</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1692361878</td>\n",
       "      <td>139772</td>\n",
       "      <td>1952</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1692361878</td>\n",
       "      <td>93927</td>\n",
       "      <td>112</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1692361878</td>\n",
       "      <td>121450</td>\n",
       "      <td>2046</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1692361878</td>\n",
       "      <td>85288</td>\n",
       "      <td>3030</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   dataset_id  doc_id  group\n",
       "0  1692361878  122282    903\n",
       "1  1692361878  139772   1952\n",
       "2  1692361878   93927    112\n",
       "3  1692361878  121450   2046\n",
       "4  1692361878   85288   3030"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cc_compute_res = pd.read_parquet(connected_component_output_path)\n",
    "cc_compute_res.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0c3e2bdc",
   "metadata": {},
   "source": [
    "Let's check if the output fuzzy duplicated documents within the same group are similar. Please note that the `group` id in your output might be different from the notebook output."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "d8fa1e8e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>group</th>\n",
       "      <th>doc_id</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>75</td>\n",
       "      <td>160982, 161038, 161124, 161109, 161121, 160991...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>112</td>\n",
       "      <td>122007, 122124, 122020, 122282, 122010, 122134...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>151</td>\n",
       "      <td>134584, 135030, 134908, 134891, 135029, 135020...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>321</td>\n",
       "      <td>94082, 94114, 94126, 94057, 94121, 94132, 9411...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>339</td>\n",
       "      <td>116230, 116237, 116223, 116236, 116176, 116204...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5460</th>\n",
       "      <td>8539</td>\n",
       "      <td>120646</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5461</th>\n",
       "      <td>8540</td>\n",
       "      <td>158174</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5462</th>\n",
       "      <td>8541</td>\n",
       "      <td>132405</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5463</th>\n",
       "      <td>8542</td>\n",
       "      <td>49199</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5464</th>\n",
       "      <td>8543</td>\n",
       "      <td>160924</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5465 rows × 2 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "      group                                             doc_id\n",
       "0        75  160982, 161038, 161124, 161109, 161121, 160991...\n",
       "1       112  122007, 122124, 122020, 122282, 122010, 122134...\n",
       "2       151  134584, 135030, 134908, 134891, 135029, 135020...\n",
       "3       321  94082, 94114, 94126, 94057, 94121, 94132, 9411...\n",
       "4       339  116230, 116237, 116223, 116236, 116176, 116204...\n",
       "...     ...                                                ...\n",
       "5460   8539                                             120646\n",
       "5461   8540                                             158174\n",
       "5462   8541                                             132405\n",
       "5463   8542                                              49199\n",
       "5464   8543                                             160924\n",
       "\n",
       "[5465 rows x 2 columns]"
      ]
     },
     "execution_count": 54,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cc_compute_res['doc_id'] = cc_compute_res['doc_id'].astype(str)\n",
    "cc_compute_res.groupby('group')['doc_id'].agg(lambda x: ', '.join(x)).reset_index()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f34b8140",
   "metadata": {},
   "source": [
    "Change the `group` number if necessary. By running the code below, we can obtain a list of near duplicated documents."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "fd01f5fe",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>dataset_id</th>\n",
       "      <th>doc_id</th>\n",
       "      <th>group</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>420</th>\n",
       "      <td>1692361878</td>\n",
       "      <td>122007</td>\n",
       "      <td>112</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>425</th>\n",
       "      <td>1692361878</td>\n",
       "      <td>122124</td>\n",
       "      <td>112</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>689</th>\n",
       "      <td>1692361878</td>\n",
       "      <td>122020</td>\n",
       "      <td>112</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>764</th>\n",
       "      <td>1692361878</td>\n",
       "      <td>122282</td>\n",
       "      <td>112</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>952</th>\n",
       "      <td>1692361878</td>\n",
       "      <td>122010</td>\n",
       "      <td>112</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "     dataset_id  doc_id  group\n",
       "420  1692361878  122007    112\n",
       "425  1692361878  122124    112\n",
       "689  1692361878  122020    112\n",
       "764  1692361878  122282    112\n",
       "952  1692361878  122010    112"
      ]
     },
     "execution_count": 55,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cc_compute_res[cc_compute_res['group']==112].head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "99a8d732",
   "metadata": {},
   "source": [
    "Print the text of near duplicated document. Please replace the `id` if necessary, `id` should be in the format of `<dataset_id>_<doc_id>`"
   ]
  },
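  {
   "cell_type": "markdown",
   "id": "c2f86d35",
   "metadata": {},
   "source": [
    "The connected components output stores `dataset_id` and `doc_id` as separate columns, while the Jaccard shuffle output uses a single combined string `id`. A small sketch of composing that string from the two parts, using a hypothetical helper named `make_doc_id` and two documents from group 112 shown earlier:\n",
    "\n",
    "```python\n",
    "def make_doc_id(dataset_id, doc_id):\n",
    "    # Join the two parts with a hyphen, matching the ids in the shuffle output.\n",
    "    return f'{dataset_id}-{doc_id}'\n",
    "\n",
    "\n",
    "ids = [make_doc_id(1692361878, doc_id) for doc_id in (122007, 122124)]\n",
    "print(ids)\n",
    "```\n",
    "\n",
    "This prints `['1692361878-122007', '1692361878-122124']`, a list that can be passed to `isin(...)` in the lookup below."
   ]
  },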
  {
   "cell_type": "code",
   "execution_count": 73,
   "id": "68883f58",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(['ประเทศสวิตเซอร์แลนด์ ได้เข้าร่วมแข่งขันกีฬาโอลิมปิกเยาวชนฤดูหนาว ครั้งที่ 3 ค.ศ. 2020 (พ.ศ. 2563) ณ เมืองโลซาน ประเทศสวิตเซอร์แลนด์ ระหว่างวันที่ 9 - 22 มกราคม พ.ศ. 2563 คณะกรรมการโอลิมปิกแห่งชาติสวิตเซอร์แลนด์ได้ส่งทีมนักกีฬาเข้าแข่งขันทั้งหมด 56 คน แบ่งเป็นเป็นชาย 32 คนและหญิง 56 คน เข้าร่วมการแข่งขันใน 15 ชนิดกีฬา\\n\\nจำนวนผู้เข้าแข่งขัน\\n\\nผลการแข่งขัน\\n\\nสเกตลีลา\\n\\nสเกตความเร็ว\\n\\nสเกตความเร็วระยะสั้น\\n\\nฮอกกี้น้ำแข็ง\\n\\nเคอร์ลิง\\n\\nสกีลงเขา\\n\\nสกีข้ามทุ่ง\\n\\nสกีกระโดดไกล\\n\\nสกีนอร์ดิกผสม\\n\\nสกีลีลา\\n\\nสกีปีนเขา\\n\\nสโนว์บอร์ด\\n\\nทวิกีฬาฤดูหนาว\\n\\nบอบสเล\\n\\nสเกเลตัน\\n\\nอ้างอิง\\n\\nแหล่งข้อมูลอื่น \\n เว็บไซต์อย่างเป็นทางการ \\n\\nประเทศสวิตเซอร์แลนด์ในโอลิมปิกเยาวชน\\nประเทศที่เข้าร่วมแข่งขันโอลิมปิกเยาวชนฤดูหนาว 2020',\n",
       "       'ประเทศบัลแกเรีย ได้เข้าร่วมแข่งขันกีฬาโอลิมปิกเยาวชนฤดูหนาว ครั้งที่ 3 ค.ศ. 2020 (พ.ศ. 2563) ณ เมืองโลซาน ประเทศสวิตเซอร์แลนด์ ระหว่างวันที่ 9 - 22 มกราคม พ.ศ. 2563 คณะกรรมการโอลิมปิกแห่งชาติบัลแกเรียได้ส่งทีมนักกีฬาเข้าแข่งขันทั้งหมด 18 คน แบ่งเป็นเป็นชาย 11 คนและหญิง 7 คน เข้าร่วมการแข่งขันใน 8 ชนิดกีฬา\\n\\nจำนวนผู้เข้าแข่งขัน\\n\\nผลการแข่งขัน\\n\\nสเกตลีลา\\n\\nสเกตความเร็ว\\n\\nสเกตความเร็วระยะสั้น\\n\\nฮอกกี้น้ำแข็ง\\n\\nเคอร์ลิง\\n\\nสกีลงเขา\\n\\nสกีข้ามทุ่ง\\n\\nสกีกระโดดไกล\\n\\nสกีนอร์ดิกผสม\\n\\nสกีลีลา\\n\\nสกีปีนเขา\\n\\nสโนว์บอร์ด\\n\\nทวิกีฬาฤดูหนาว\\n\\nลูช\\n\\nบอบสเล\\n\\nสเกเลตัน\\n\\nอ้างอิง\\n\\nแหล่งข้อมูลอื่น \\n เว็บไซต์อย่างเป็นทางการ \\n\\nประเทศบัลแกเรียในโอลิมปิกเยาวชน\\nประเทศที่เข้าร่วมแข่งขันโอลิมปิกเยาวชนฤดูหนาว 2020'],\n",
       "      dtype=object)"
      ]
     },
     "execution_count": 73,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "jaccard_shuffle_res[jaccard_shuffle_res['id'].isin(['1692361878-121545','1692361878-121487'])]['text'].unique()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3b6578b4",
   "metadata": {},
   "source": [
    "Below is the English translation of the output above. We can see that the two documents are indeed very similar to each other.\n",
    "- `Text 1`:\n",
    "```\n",
    "Switzerland participated in the 3rd Youth Olympic Winter Games in 2020 (B.E. 2563) in Lausanne, Switzerland from January 9 - 22, 2563. The Swiss Olympic Committee sent a total of 56 athletes, consisting of 32 men and 56 women, to compete in 15 sports.\n",
    "Number of Competitors:\n",
    "Competition Results:\n",
    "Figure Skating\n",
    "Speed Skating\n",
    "Short Track Speed Skating\n",
    "Ice Hockey\n",
    "Curling\n",
    "Alpine Skiing\n",
    "Cross-Country Skiing\n",
    "Ski Jumping\n",
    "Nordic Combined\n",
    "Freestyle Skiing\n",
    "Ski Mountaineering\n",
    "Snowboard\n",
    "Biathlon\n",
    "Bobsleigh\n",
    "Skeleton\n",
    "References:\n",
    "Other Resources:\n",
    "Official Website\n",
    "Switzerland at the Youth Olympics\n",
    "Countries at the 2020 Youth Winter Olympics\n",
    "```\n",
    "- `Text 2`:\n",
    "```\n",
    "Bulgaria participated in the 3rd Youth Olympic Winter Games in 2020 (B.E. 2563) in Lausanne, Switzerland from January 9 - 22, 2563. The Bulgarian Olympic Committee sent a total of 18 athletes, consisting of 11 men and 7 women, to compete in 8 sports.\n",
    "Number of Competitors:\n",
    "Competition Results:\n",
    "Figure Skating\n",
    "Speed Skating\n",
    "Short Track Speed Skating\n",
    "Ice Hockey\n",
    "Curling\n",
    "Alpine Skiing\n",
    "Cross-Country Skiing\n",
    "Ski Jumping\n",
    "Nordic Combined\n",
    "Freestyle Skiing\n",
    "Ski Mountaineering\n",
    "Snowboard\n",
    "Biathlon\n",
    "Luge\n",
    "Bobsleigh\n",
    "Skeleton\n",
    "References:\n",
    "Other Resources:\n",
    "Official Website\n",
    "Bulgaria at the Youth Olympics\n",
    "Countries at the 2020 Youth Winter Olympics\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f36436f3",
   "metadata": {},
   "source": [
    "### 5.6 Fuzzy deduplication wrapper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "id": "eb52ec06",
   "metadata": {},
   "outputs": [],
   "source": [
    "from nemo_curator import FuzzyDuplicates, FuzzyDuplicatesConfig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "id": "625c1828",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Input\n",
    "fuzzy_dedup_data_path = added_id_output_path\n",
    "#Output\n",
    "fuzzy_dedup_base_output_path = os.path.join(data_dir,\"fuzzy_wrapper\")\n",
    "fuzzy_dedup_log_dir = os.path.join(fuzzy_dedup_base_output_path,'log')\n",
    "fuzzy_dedup_cache_dir = os.path.join(fuzzy_dedup_base_output_path,'cache')\n",
    "fuzzy_dedup_output_dir = os.path.join(fuzzy_dedup_base_output_path,'data')\n",
    "#Specify dataset name\n",
    "dataset_name = 'TH_wikipedia'\n",
    "\n",
    "#Relevant parameters\n",
    "id_field = 'id'\n",
    "text_field = 'text'\n",
    "filetype = \"parquet\"\n",
    "\n",
    "!mkdir -p {fuzzy_dedup_base_output_path}\n",
    "!mkdir -p {fuzzy_dedup_log_dir}\n",
    "!mkdir -p {fuzzy_dedup_cache_dir}\n",
    "!mkdir -p {fuzzy_dedup_output_dir}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cb76d8e5",
   "metadata": {},
   "source": [
    "**[Optional]** If the cache folder is not empty, please CLEAR the folder before proceeding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "id": "e7fb4c4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "#!rm -r {fuzzy_dedup_cache_dir}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "id": "2368443f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Reading 1 files\n",
      "Stage1: Starting Minhash + LSH computation\n",
      "Stage1: Minhash + LSH complete!\n",
      "Stage2 (False Postive Check): Starting Map_Buckets\n",
      "Stage2 (False Postive Check): Map_Buckets Complete!\n",
      "Stage3 (False Postive Check): Shuffle docs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|                                                                                                                                                                                                                                                                                                                                                       | 0/1 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Started processing bucket-map partitions 0 through 1 of 1\n",
      "Using 1 text partitions.\n",
      "Starting text bytes aware shuffle\n",
      "Will write 32059 rows to disk\n",
      "Text-df partition  1/1 completed in 2.764477491378784\n",
      "Bucket partition  1/1 completed in 2.783641815185547\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00,  2.79s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Stage3 (False Postive Check): Shuffle docs complete!\n",
      "Stage4 (False Postive Check): Jaccard Similarity in Buckets\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Stage4 (False Postive Check): Jaccard Similarity in Buckets Complete!\n",
      "Stage5: Connected Components across buckets\n",
      "batch_id = 0/1, time = 0.2485034465789795\n",
      "# of groups 5458\n",
      "# of docs removed 3086\n",
      "assert num_nodes:8544==labels_df:8544 passed\n",
      "Stage5: Connected Components across buckets complete!\n",
      "Writing to disk complete for 1 partitions\n",
      "Time taken for Connected Component: 20.06704068183899 s\n"
     ]
    }
   ],
   "source": [
    "with dask.config.set({\"dataframe.backend\": 'cudf'}):\n",
    "        \n",
    "        t0 = time.time()\n",
    "        \n",
    "        input_dataset = DocumentDataset.read_json(fuzzy_dedup_data_path, backend='cudf')\n",
    "\n",
    "        fuzzy_dedup_config = FuzzyDuplicatesConfig(\n",
    "            cache_dir=fuzzy_dedup_cache_dir,\n",
    "            id_field=id_field,\n",
    "            text_field=text_field,\n",
    "            seed=seed, #Use the seed set in Minhash section for consistency\n",
    "            char_ngrams=5,\n",
    "            num_buckets=20,\n",
    "            hashes_per_bucket=13,\n",
    "            use_64_bit_hash=False,\n",
    "            buckets_per_shuffle=5,\n",
    "            false_positive_check=True,\n",
    "            num_anchors=2,\n",
    "            jaccard_threshold=0.8,\n",
    "        )\n",
    "        fuzzy_dup = FuzzyDuplicates(logger=fuzzy_dedup_log_dir, config=fuzzy_dedup_config)\n",
    "        duplicates = fuzzy_dup(dataset=input_dataset)\n",
    "        \n",
    "        duplicates.to_parquet(fuzzy_dedup_output_dir, write_to_filename=False)\n",
    "       \n",
    "        print(f\"Time taken for Connected Component: {time.time()-t0} s\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "id": "14bfe3bc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>id</th>\n",
       "      <th>group</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>TH_wiki-0000134798</td>\n",
       "      <td>736</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>TH_wiki-0000116226</td>\n",
       "      <td>1526</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>TH_wiki-0000126796</td>\n",
       "      <td>2934</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>TH_wiki-0000138218</td>\n",
       "      <td>156</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>TH_wiki-0000085437</td>\n",
       "      <td>2722</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                   id  group\n",
       "0  TH_wiki-0000134798    736\n",
       "1  TH_wiki-0000116226   1526\n",
       "2  TH_wiki-0000126796   2934\n",
       "3  TH_wiki-0000138218    156\n",
       "4  TH_wiki-0000085437   2722"
      ]
     },
     "execution_count": 61,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "fuzzy_dedup_res = pd.read_parquet(fuzzy_dedup_output_dir)\n",
    "fuzzy_dedup_res.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d2726cf9",
   "metadata": {},
   "source": [
    "## 6. Remove duplicates\n",
    "\n",
    "Now we have duplicated document IDs output by both exact deduplication and fuzzy deduplication. We will run this section to remove those documents. This is done be loading the output .parquet files and the unicode fixed input dataset in .jsonl as DataFrame. Then use DataFrame operation to remove the duplicated documents."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e4dd78db",
   "metadata": {},
   "source": [
    "Define parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "id": "0027c8d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Input\n",
    "dataset_dir = added_id_output_path\n",
    "\n",
    "#Output\n",
    "dudped_output_dir = os.path.join(data_dir,\"remove_duplicate/result.parquet\")\n",
    "\n",
    "#Relevant parameters\n",
    "input_id_field = 'id'\n",
    "id_prefix = add_ID_id_prefix\n",
    "\n",
    "!mkdir -p {dudped_output_dir}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a373860d",
   "metadata": {},
   "source": [
    "We will first process the result of exact deduplication. Since result of exact deduplication contains original ID used in input dataset, it is more straightforward to deal with."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "id": "f59e92c3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Reading 1 files\n",
      "Reading 1 files\n"
     ]
    }
   ],
   "source": [
    "#Load .jsonl dataset\n",
    "input_dataset = DocumentDataset.read_json(dataset_dir, backend='cudf')\n",
    "\n",
    "#Load exact deduplicate result and extract list of duplicated document ID\n",
    "exact_duplicates = DocumentDataset.read_parquet(os.path.join(exact_dedup_output_dir,\"_exact_duplicates.parquet\"), backend='cudf')\n",
    "exact_docs_to_remove = exact_duplicates.df.map_partitions(\n",
    "    lambda x: x[x._hashes.duplicated(keep=\"first\")]\n",
    ")\n",
    "\n",
    "#Remove the duplicated document from input dataset\n",
    "result = input_dataset.df[\n",
    "    ~input_dataset.df[input_id_field].isin(exact_docs_to_remove[input_id_field].compute())\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f55d6737",
   "metadata": {},
   "source": [
    "For result of fuzzy deduplication, we need to first reconstructed document ID by combining `dataset_id` and `doc_id`, then use the reconstructed `ID` for removal"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3b9c122d",
   "metadata": {},
   "source": [
    "**[Optional]** Uncomment the cell to use result from step by step fuzzy deduplication"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "id": "c6a1bb0a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# #List of id_prefix used in Add ID\n",
    "# base_ids = [id_prefix]\n",
    "\n",
    "# #Obtain a mapping between `dataset_id` and `id_prefix`\n",
    "# df = cudf.DataFrame()\n",
    "# df['base_id'] = [base_id for base_id in base_ids]\n",
    "# df['dataset_id'] = df['base_id'].hash_values()\n",
    "# df_pd = df.to_pandas()\n",
    "# mapping = {\n",
    "#       hashed_id: base_id\n",
    "#       for base_id, hashed_id in zip(df_pd['base_id'], df_pd['dataset_id'])\n",
    "# }\n",
    "\n",
    "# #Load result of fuzzy deduplication \n",
    "# fuzzy_duplicates = pd.read_parquet(connected_component_output_path)\n",
    "# #Reconstruct the original document ID\n",
    "# fuzzy_duplicates['id']=fuzzy_duplicates.apply(lambda x: f\"{mapping[x['dataset_id']]}-{x['doc_id']:010d}\", axis=1)\n",
    "\n",
    "# #Generate list of near duplicate document ID\n",
    "# fuzzy_docs_to_remove = fuzzy_duplicates.drop_duplicates(subset=['group'], keep='first')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "id": "746d3673",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Loads result from fuzzy dedup wrapper\n",
    "fuzzy_duplicates = pd.read_parquet(fuzzy_dedup_output_dir)\n",
    "\n",
    "#Generate list of near duplicate document ID\n",
    "fuzzy_docs_to_remove = fuzzy_duplicates.drop_duplicates(subset=['group'], keep='first')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "id": "62b34838",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Remove near duplicates\n",
    "result = result[~result[input_id_field].isin(fuzzy_docs_to_remove[input_id_field])]\n",
    "\n",
    "#Save final result to local\n",
    "result.to_parquet(dudped_output_dir, write_to_filename=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "edfa52ce",
   "metadata": {},
   "source": [
    "Verify the result of duplicate removal. We can see that the number of document in resultant document is less than the original dataset (length = 161748)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "id": "78eee9b3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Length of duplicate removed dataset:156265\n"
     ]
    }
   ],
   "source": [
    "res = pd.read_parquet(dudped_output_dir)\n",
    "print(f\"Length of duplicate removed dataset:{len(res)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15e07a32",
   "metadata": {},
   "source": [
    "Close the GPU Dask Cluster.You might encounter error such as `Caught signal 11`.It's OK, just rerun the cell again."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "id": "8e807bd7",
   "metadata": {},
   "outputs": [],
   "source": [
    "client.cluster.close()\n",
    "client.shutdown()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a416a293",
   "metadata": {},
   "source": [
    "## 7. Heuristic Fitlering\n",
    "\n",
    "In this section, we will apply multiple heuristic filters to the dataset, record the heuristic score for documents and documents removed for each filter. For each heuristic filter, the filter calculates a quality scores based on user defined heuristics/algorithms and classifies documents into high quality documents or low quality documents if the quality score is above the user defined threshold.\n",
    "\n",
    "Sample lists of heuristic filters can be found in `./config/`\n",
    "- `heuristic_filter_en.yaml`: Sample heuristic filter list for English dataset\n",
    "- `heuristic_filter_non-en.yaml`:Sample heuristic filter list for Non-English dataset\n",
    "- `heuristic_filter_code.yaml`:Sample heuristic filter list for Code language dataset\n",
    "Please adjust the sample list e.g. remove/add filters or change filter threshold based on your own use case. In this example, `heuristic_filter_non-en.yaml` will be used.\n",
    "\n",
    "For detailed implementation and description of each heuristic filter, please refer to `./NeMo-Curator/nemo-curator/filters/heuristics_filter.py`. For customized heuristic filter implementation, user shall follow the sample implementations, write customized filters and update the .yaml files accordingly.\n",
    "\n",
    "For analysis of impact of each filters on the dataset, user should set `log-score` to true for the filters in the corresponding config .yaml file. This will output quality score for all filters in separate .txt files for each individual filter. With the quality score and filter threshold, use can calculate quality score distribution and other analysis to assess the effectiveness of each filter.\n",
    "\n",
    "In this example, in order to get a comprehensive output of each filter, we are iterating through ever filter using a for loop and saving the intermediate result. This process will involve extensive I/O operations and is less effective. Alternatively, after loading input dataset and filter pipeline, user can simply call `filter_pipeline(dataset)` to obtain the final filtered result."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "id": "b988ad1e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from nemo_curator.utils.config_utils import build_filter_pipeline\n",
    "from nemo_curator import Score, Filter, ScoreFilter\n",
    "from nemo_curator.utils.file_utils import get_batched_files,expand_outdir_and_mkdir"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "097a1b48",
   "metadata": {},
   "source": [
    "**[Optional]** The following cell is to remove warning from dask."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "id": "44552288",
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "\n",
    "# Disable the metadata warning\n",
    "warnings.filterwarnings(\"ignore\",module=\"dask.dataframe.core\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9a59699d",
   "metadata": {},
   "source": [
    "Create a CPU Dask Cluster."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "id": "b8f80ab3",
   "metadata": {},
   "outputs": [],
   "source": [
    "cluster = LocalCluster(n_workers=10, processes=True, memory_limit='16GB')\n",
    "client = Client(cluster)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a7702918",
   "metadata": {},
   "source": [
    "Define some helper functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "id": "6f2e7523",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_dataframe_complement(original_df, filtered_df):\n",
    "    def partition_complement(part_original_df, partition_info=None):\n",
    "        if not partition_info:\n",
    "            return part_original_df\n",
    "        part_filtered_df = filtered_df.get_partition(partition_info[\"number\"])\n",
    "        complement_mask = ~part_original_df.index.isin(part_filtered_df.index.persist())\n",
    "        complement_df = part_original_df[complement_mask]\n",
    "        return complement_df\n",
    "\n",
    "    return original_df.map_partitions(partition_complement)\n",
    "\n",
    "def write_scores(df, output_dir):\n",
    "    for column in df.columns:\n",
    "        output_path = os.path.join(output_dir, f\"{column}.txt\")\n",
    "        df[column].to_csv(output_path, single_file=True, encoding=\"utf-8\", header=False, index=False, mode=\"a\")\n",
    "\n",
    "def get_score_fields(pipeline):\n",
    "    score_fields = []\n",
    "    for nc_module in pipeline.modules:\n",
    "        if isinstance(nc_module, Score) or isinstance(nc_module, ScoreFilter):\n",
    "            if nc_module.score_field:\n",
    "                score_fields.append(nc_module.score_field)\n",
    "    return score_fields"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "227fa8b0",
   "metadata": {},
   "source": [
    "Define parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 93,
   "id": "a894f90f",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Input\n",
    "HF_input_data_dir = dudped_output_dir\n",
    "input_file_type = 'parquet'\n",
    "batch_size = 1\n",
    "\n",
    "#Output\n",
    "HF_base_output_path = os.path.join(data_dir,'heuristic_filtering')\n",
    "kept_document_dir =  os.path.join(HF_base_output_path,'data','hq.parquet')\n",
    "removed_document_dir =  os.path.join(HF_base_output_path,'data','lq.parquet')\n",
    "output_document_score_dir =  os.path.join(HF_base_output_path,'data','score')\n",
    "output_file_type = 'parquet'\n",
    "\n",
    "#Relevant parameters\n",
    "filter_config_file = './config/heuristic_filter_non-en.yaml'\n",
    "input_id_field = 'id'\n",
    "\n",
    "#Set to False if do not want to save intermediate results\n",
    "is_cache = True\n",
    "\n",
    "!mkdir -p {kept_document_dir}\n",
    "!mkdir -p {removed_document_dir}\n",
    "!mkdir -p {output_document_score_dir}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ccea406e",
   "metadata": {},
   "source": [
    "Run heuristic filtering"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 94,
   "id": "03b3da27",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Reading 1 files\n",
      "Saving data for symbol_to_word\n",
      "Writing to disk complete for 1 partitions\n",
      "Saving data for numbers_ratio\n",
      "Writing to disk complete for 1 partitions\n",
      "Saving data for urls_ratio\n",
      "Writing to disk complete for 1 partitions\n",
      "Saving data for white_space\n",
      "Writing to disk complete for 1 partitions\n",
      "Saving data for parentheses_ratio\n",
      "Writing to disk complete for 1 partitions\n",
      "Saving data for boilerplate_string_ratio\n",
      "Writing to disk complete for 1 partitions\n",
      "Saving data for repeated_lines\n",
      "Writing to disk complete for 1 partitions\n",
      "Saving data for repeated_paragraphs\n",
      "Writing to disk complete for 1 partitions\n",
      "Saving data for repeated_lines_char\n",
      "Writing to disk complete for 1 partitions\n",
      "Saving data for repeated_paragraphs_char\n",
      "Writing to disk complete for 1 partitions\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.10/dist-packages/nemo_curator/utils/distributed_utils.py:379: UserWarning: Empty partition found\n",
      "  warnings.warn(f\"Empty partition found\")\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saving data for word_count\n",
      "Writing to disk complete for 1 partitions\n",
      "Saving data for repeating_top_2grams\n",
      "Writing to disk complete for 1 partitions\n",
      "Saving data for repeating_top_3grams\n",
      "Writing to disk complete for 1 partitions\n",
      "Saving data for repeating_top_4grams\n",
      "Writing to disk complete for 1 partitions\n",
      "Writing to disk complete for 1 partitions\n",
      "Time taken for Heuristic filtering: 1120.5212895870209 s\n"
     ]
    }
   ],
   "source": [
    "t0 = time.time()\n",
    "\n",
    "#Load filters from config\n",
    "filter_pipeline = build_filter_pipeline(filter_config_file)\n",
    "score_fields = get_score_fields(filter_pipeline)\n",
    "\n",
    "# Load dataset\n",
    "dataset = DocumentDataset.read_parquet(HF_input_data_dir, backend='pandas', add_filename=True)\n",
    "\n",
    "\n",
    "# Iterate through filters. For each filter, the low quality document will be removed from the dataset and output to corresponding folder for analysis\n",
    "# Output of previous filter will be input of the next filter\n",
    "if is_cache:\n",
    "    curr_dataset = prev_dataset = dataset\n",
    "    for filter_module in filter_pipeline.modules:\n",
    "        #Apply filter\n",
    "        curr_dataset = filter_module(curr_dataset).persist()\n",
    "\n",
    "        #Output filtered document\n",
    "        print(f\"Saving data for {filter_module.filter_obj._name}\")\n",
    "        removed_df = get_dataframe_complement(prev_dataset.df, curr_dataset.df)\n",
    "        removed_filter_dir = os.path.join(removed_document_dir, filter_module.filter_obj._name)\n",
    "        expand_outdir_and_mkdir(removed_filter_dir)\n",
    "        write_to_disk(removed_df, removed_filter_dir, write_to_filename=True, output_type=output_file_type)\n",
    "        prev_dataset = curr_dataset\n",
    "    filtered_dataset = curr_dataset\n",
    "else:\n",
    "    filtered_dataset = filter_pipeline(dataset)\n",
    "\n",
    "# Write scores of retained doucment to separate directory\n",
    "output_df = filtered_dataset.df[[input_id_field, *score_fields]]\n",
    "write_scores(output_df, output_document_score_dir)\n",
    "\n",
    "# Remove scores from dataset df\n",
    "filtered_dataset = DocumentDataset(filtered_dataset.df.drop(columns=score_fields))\n",
    "\n",
    "# Output filtered dataset\n",
    "filtered_dataset.to_parquet(kept_document_dir, write_to_filename=True)\n",
    "\n",
    "print(f\"Time taken for Heuristic filtering: {time.time()-t0} s\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a53b04e9",
   "metadata": {},
   "source": [
    "Verify the result."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 95,
   "id": "07475373",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset size after heuristic filtering:192786\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>filename</th>\n",
       "      <th>id</th>\n",
       "      <th>language</th>\n",
       "      <th>source_id</th>\n",
       "      <th>text</th>\n",
       "      <th>title</th>\n",
       "      <th>url</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>part.0.parquet</td>\n",
       "      <td>TH_wiki-0000000001</td>\n",
       "      <td>TH</td>\n",
       "      <td>thwiki-20240201-thwiki-20240201-pages-articles...</td>\n",
       "      <td>ดาราศาสตร์ คือวิชาวิทยาศาสตร์ที่ศึกษาวัตถุในท้...</td>\n",
       "      <td>ดาราศาสตร์</td>\n",
       "      <td>https://th.wikipedia.org/wiki/%E0%B8%94%E0%B8%...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>part.0.parquet</td>\n",
       "      <td>TH_wiki-0000000002</td>\n",
       "      <td>TH</td>\n",
       "      <td>thwiki-20240201-thwiki-20240201-pages-articles...</td>\n",
       "      <td>ภูมิศาสตร์ (,  แปลว่า \"การพรรณนาเกี่ยวกับโลก\")...</td>\n",
       "      <td>ภูมิศาสตร์</td>\n",
       "      <td>https://th.wikipedia.org/wiki/%E0%B8%A0%E0%B8%...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>part.0.parquet</td>\n",
       "      <td>TH_wiki-0000000003</td>\n",
       "      <td>TH</td>\n",
       "      <td>thwiki-20240201-thwiki-20240201-pages-articles...</td>\n",
       "      <td>พันทิป.คอม หรือพันทิป ก่อตั้งขึ้นเมื่อวันที่ 7...</td>\n",
       "      <td>พันทิป.คอม</td>\n",
       "      <td>https://th.wikipedia.org/wiki/%E0%B8%9E%E0%B8%...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>part.0.parquet</td>\n",
       "      <td>TH_wiki-0000000004</td>\n",
       "      <td>TH</td>\n",
       "      <td>thwiki-20240201-thwiki-20240201-pages-articles...</td>\n",
       "      <td>พันธุ์ทิพย์พลาซ่า () เป็นศูนย์การค้าเกี่ยวกับเ...</td>\n",
       "      <td>พันธุ์ทิพย์พลาซ่า</td>\n",
       "      <td>https://th.wikipedia.org/wiki/%E0%B8%9E%E0%B8%...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>part.0.parquet</td>\n",
       "      <td>TH_wiki-0000000005</td>\n",
       "      <td>TH</td>\n",
       "      <td>thwiki-20240201-thwiki-20240201-pages-articles...</td>\n",
       "      <td>วิทยาการคอมพิวเตอร์ศึกษาเกี่ยวกับโครงสร้างพื้น...</td>\n",
       "      <td>วิทยาการคอมพิวเตอร์</td>\n",
       "      <td>https://th.wikipedia.org/wiki/%E0%B8%A7%E0%B8%...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "         filename                  id language  \\\n",
       "1  part.0.parquet  TH_wiki-0000000001       TH   \n",
       "2  part.0.parquet  TH_wiki-0000000002       TH   \n",
       "3  part.0.parquet  TH_wiki-0000000003       TH   \n",
       "4  part.0.parquet  TH_wiki-0000000004       TH   \n",
       "5  part.0.parquet  TH_wiki-0000000005       TH   \n",
       "\n",
       "                                           source_id  \\\n",
       "1  thwiki-20240201-thwiki-20240201-pages-articles...   \n",
       "2  thwiki-20240201-thwiki-20240201-pages-articles...   \n",
       "3  thwiki-20240201-thwiki-20240201-pages-articles...   \n",
       "4  thwiki-20240201-thwiki-20240201-pages-articles...   \n",
       "5  thwiki-20240201-thwiki-20240201-pages-articles...   \n",
       "\n",
       "                                                text                title  \\\n",
       "1  ดาราศาสตร์ คือวิชาวิทยาศาสตร์ที่ศึกษาวัตถุในท้...           ดาราศาสตร์   \n",
       "2  ภูมิศาสตร์ (,  แปลว่า \"การพรรณนาเกี่ยวกับโลก\")...           ภูมิศาสตร์   \n",
       "3  พันทิป.คอม หรือพันทิป ก่อตั้งขึ้นเมื่อวันที่ 7...           พันทิป.คอม   \n",
       "4  พันธุ์ทิพย์พลาซ่า () เป็นศูนย์การค้าเกี่ยวกับเ...    พันธุ์ทิพย์พลาซ่า   \n",
       "5  วิทยาการคอมพิวเตอร์ศึกษาเกี่ยวกับโครงสร้างพื้น...  วิทยาการคอมพิวเตอร์   \n",
       "\n",
       "                                                 url  \n",
       "1  https://th.wikipedia.org/wiki/%E0%B8%94%E0%B8%...  \n",
       "2  https://th.wikipedia.org/wiki/%E0%B8%A0%E0%B8%...  \n",
       "3  https://th.wikipedia.org/wiki/%E0%B8%9E%E0%B8%...  \n",
       "4  https://th.wikipedia.org/wiki/%E0%B8%9E%E0%B8%...  \n",
       "5  https://th.wikipedia.org/wiki/%E0%B8%A7%E0%B8%...  "
      ]
     },
     "execution_count": 95,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "res = pd.read_parquet(kept_document_dir)\n",
    "print(f\"Dataset size after heuristic filtering:{len(res)}\")\n",
    "res.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "24e8b173",
   "metadata": {},
   "source": [
    "Close the CPU Dask Cluster"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 96,
   "id": "12508f5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "client.cluster.close()\n",
    "client.shutdown()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83e4aed1",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
